Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add necessary trait bounds to balanced merkle tree #5232

1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions base_layer/mmr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ tari_utilities = { git = "https://github.com/tari-project/tari_utilities.git", t
tari_crypto = { git = "https://github.com/tari-project/tari-crypto.git", tag = "v0.16.8" }
tari_common = {path = "../../common"}
thiserror = "1.0.26"
borsh = "0.9.3"
digest = "0.9.0"
log = "0.4"
serde = { version = "1.0.97", features = ["derive"] }
Expand Down
70 changes: 45 additions & 25 deletions base_layer/mmr/src/balanced_binary_merkle_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,29 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::{collections::HashMap, convert::TryInto, marker::PhantomData};
use std::{
collections::HashMap,
convert::{TryFrom, TryInto},
marker::PhantomData,
};

use borsh::{BorshDeserialize, BorshSerialize};
use digest::Digest;
use serde::{Deserialize, Serialize};
use tari_common::DomainDigest;
use tari_utilities::ByteArray;
use thiserror::Error;

use crate::{common::hash_together, BalancedBinaryMerkleTree, Hash};

#[derive(Debug)]
pub(crate) fn cast_to_u32(value: usize) -> Result<u32, BalancedBinaryMerkleProofError> {
u32::try_from(value).map_err(|_| BalancedBinaryMerkleProofError::MathOverflow)
}

#[derive(BorshDeserialize, BorshSerialize, Deserialize, Serialize, Clone, Debug, Default, PartialEq, Eq)]
pub struct BalancedBinaryMerkleProof<D> {
pub path: Vec<Hash>,
pub node_index: usize,
pub node_index: u32,
_phantom: PhantomData<D>,
}

Expand All @@ -55,7 +65,10 @@ where D: Digest + DomainDigest
&computed_root == root
}

pub fn generate_proof(tree: &BalancedBinaryMerkleTree<D>, leaf_index: usize) -> Self {
pub fn generate_proof(
tree: &BalancedBinaryMerkleTree<D>,
leaf_index: usize,
) -> Result<Self, BalancedBinaryMerkleProofError> {
let mut node_index = tree.get_node_index(leaf_index);
let mut proof = Vec::new();
while node_index > 0 {
Expand All @@ -67,20 +80,22 @@ where D: Digest + DomainDigest
// Traverse to parent
node_index = parent;
}
Self {
Ok(Self {
path: proof,
node_index: tree.get_node_index(leaf_index),
node_index: cast_to_u32(tree.get_node_index(leaf_index))?,
_phantom: PhantomData,
}
})
}
}

#[derive(Debug, Error)]
pub enum MergedBalancedBinaryMerkleProofError {
pub enum BalancedBinaryMerkleProofError {
#[error("Can't merge zero proofs.")]
CantMergeZeroProofs,
#[error("Bad proof semantics")]
BadProofSemantics,
#[error("Math overflow")]
MathOverflow,
jorgeantonio21 marked this conversation as resolved.
Show resolved Hide resolved
}

/// Flag to indicate if proof data represents an index or a node hash
Expand All @@ -94,8 +109,8 @@ pub enum MergedBalancedBinaryMerkleDataType {
#[derive(Debug)]
pub struct MergedBalancedBinaryMerkleProof<D> {
pub paths: Vec<Vec<(MergedBalancedBinaryMerkleDataType, Vec<u8>)>>, // these tuples can contain indexes or hashes!
pub node_indices: Vec<usize>,
pub heights: Vec<usize>,
pub node_indices: Vec<u32>,
pub heights: Vec<u32>,
_phantom: PhantomData<D>,
}

Expand All @@ -104,20 +119,23 @@ where D: Digest + DomainDigest
{
pub fn create_from_proofs(
proofs: Vec<BalancedBinaryMerkleProof<D>>,
) -> Result<Self, MergedBalancedBinaryMerkleProofError> {
let heights = proofs.iter().map(|proof| proof.path.len()).collect::<Vec<_>>();
) -> Result<Self, BalancedBinaryMerkleProofError> {
let heights = proofs
.iter()
.map(|proof| cast_to_u32(proof.path.len()))
.collect::<Result<Vec<_>, _>>()?;
let max_height = heights
.iter()
.max()
.ok_or(MergedBalancedBinaryMerkleProofError::CantMergeZeroProofs)?;
.ok_or(BalancedBinaryMerkleProofError::CantMergeZeroProofs)?;
let mut indices = proofs.iter().map(|proof| proof.node_index).collect::<Vec<_>>();
let mut paths = vec![Vec::new(); proofs.len()];
let mut join_indices = vec![None; proofs.len()];
for height in (0..*max_height).rev() {
let mut hash_map = HashMap::new();
for (index, proof) in proofs.iter().enumerate() {
// If this path was already joined ignore it.
if join_indices[index].is_none() && proof.path.len() > height {
if join_indices[index].is_none() && proof.path.len() > height as usize {
let parent = (indices[index] - 1) >> 1;
if let Some(other_proof) = hash_map.insert(parent, index) {
join_indices[index] = Some(other_proof);
Expand All @@ -129,7 +147,7 @@ where D: Digest + DomainDigest
0,
(
MergedBalancedBinaryMerkleDataType::Hash,
proof.path[proof.path.len() - 1 - height].clone(),
proof.path[proof.path.len() - 1 - height as usize].clone(),
),
);
}
Expand All @@ -149,19 +167,19 @@ where D: Digest + DomainDigest
mut self,
root: &Hash,
leaves_hashes: Vec<Hash>,
) -> Result<bool, MergedBalancedBinaryMerkleProofError> {
) -> Result<bool, BalancedBinaryMerkleProofError> {
// Check that the proof and verifier data match
let n = self.node_indices.len(); // number of merged proofs
if self.paths.len() != n || leaves_hashes.len() != n {
return Err(MergedBalancedBinaryMerkleProofError::BadProofSemantics);
return Err(BalancedBinaryMerkleProofError::BadProofSemantics);
}

let mut computed_hashes = leaves_hashes;
let max_height = self
.heights
.iter()
.max()
.ok_or(MergedBalancedBinaryMerkleProofError::CantMergeZeroProofs)?;
.ok_or(BalancedBinaryMerkleProofError::CantMergeZeroProofs)?;

// We need to compute the hashes row by row to be sure they are processed correctly.
for height in (0..*max_height).rev() {
Expand All @@ -177,14 +195,14 @@ where D: Digest + DomainDigest
.1
.as_bytes()
.try_into()
.map_err(|_| MergedBalancedBinaryMerkleProofError::BadProofSemantics)?,
.map_err(|_| BalancedBinaryMerkleProofError::BadProofSemantics)?,
);

// The index must also point to one of the proofs
if index < hashes.len() {
&hashes[index]
} else {
return Err(MergedBalancedBinaryMerkleProofError::BadProofSemantics);
return Err(BalancedBinaryMerkleProofError::BadProofSemantics);
}
},
MergedBalancedBinaryMerkleDataType::Hash => &hash_or_index.1,
Expand Down Expand Up @@ -223,11 +241,11 @@ mod test {
let hash_last = leaves[n - 1].clone();
let bmt = BalancedBinaryMerkleTree::<DomainSeparatedHasher<Blake256, TestDomain>>::create(leaves);
let root = bmt.get_merkle_root();
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, 0);
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, 0).unwrap();
assert!(proof.verify(&root, hash_0));
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n / 2);
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n / 2).unwrap();
assert!(proof.verify(&root, hash_n_half));
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n - 1);
let proof = BalancedBinaryMerkleProof::generate_proof(&bmt, n - 1).unwrap();
assert!(proof.verify(&root, hash_last));
}
}
Expand All @@ -241,7 +259,8 @@ mod test {
let proofs = indices
.iter()
.map(|i| BalancedBinaryMerkleProof::generate_proof(&bmt, *i))
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()
.unwrap();
let merged_proof = MergedBalancedBinaryMerkleProof::create_from_proofs(proofs).unwrap();
assert!(merged_proof
.verify_consume(&root, indices.iter().map(|i| leaves[*i].clone()).collect::<Vec<_>>())
Expand All @@ -255,7 +274,8 @@ mod test {
let root = bmt.get_merkle_root();
let proofs = (0..255)
.map(|i| BalancedBinaryMerkleProof::generate_proof(&bmt, i))
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()
.unwrap();
let merged_proof = MergedBalancedBinaryMerkleProof::create_from_proofs(proofs).unwrap();
assert!(merged_proof.verify_consume(&root, leaves).unwrap());
}
Expand Down
12 changes: 9 additions & 3 deletions base_layer/mmr/src/balanced_binary_merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,24 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::marker::PhantomData;
use std::{convert::TryFrom, marker::PhantomData};

use digest::Digest;
use tari_common::DomainDigest;
use thiserror::Error;

use crate::{common::hash_together, ArrayLike, Hash};

pub(crate) fn cast_to_u32(value: usize) -> Result<u32, BalancedBinaryMerkleTreeError> {
u32::try_from(value).map_err(|_| BalancedBinaryMerkleTreeError::MathOverFlow)
}

#[derive(Clone, Debug, PartialEq, Eq, Error)]
pub enum BalancedBinaryMerkleTreeError {
#[error("There is no leaf with the hash provided.")]
LeafNotFound,
#[error("Math overflow")]
MathOverFlow,
}

// The hashes are perfectly balanced binary tree, so parent at index `i` (0-based) has children at positions `2*i+1` and
Expand Down Expand Up @@ -92,7 +98,7 @@ where D: Digest + DomainDigest
leaf_index + (self.hashes.len() >> 1)
}

pub fn find_leaf_index_for_hash(&self, hash: &Hash) -> Result<usize, BalancedBinaryMerkleTreeError> {
pub fn find_leaf_index_for_hash(&self, hash: &Hash) -> Result<u32, BalancedBinaryMerkleTreeError> {
let pos = self
.hashes
.position(hash)
Expand All @@ -102,7 +108,7 @@ where D: Digest + DomainDigest
// The hash provided was not for leaf, but for node.
Err(BalancedBinaryMerkleTreeError::LeafNotFound)
} else {
Ok(pos - (self.hashes.len() >> 1))
Ok(cast_to_u32(pos - (self.hashes.len() >> 1))?)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion base_layer/mmr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ pub mod pruned_hashset;
pub use backend::{ArrayLike, ArrayLikeExt};
pub use balanced_binary_merkle_proof::{
BalancedBinaryMerkleProof,
BalancedBinaryMerkleProofError,
MergedBalancedBinaryMerkleProof,
MergedBalancedBinaryMerkleProofError,
};
pub use balanced_binary_merkle_tree::{BalancedBinaryMerkleTree, BalancedBinaryMerkleTreeError};
/// MemBackendVec is a shareable, memory only, vector that can be be used with MmrCache to store checkpoints.
Expand Down