Skip to content

Commit

Permalink
Feat: BMT prove uses position path (#61)
Browse files Browse the repository at this point in the history
* BMT prove uses position path

* Format

* Remove unnecessary storage inserts

* Remove unnecessary mut from Node::create_node

* Remove unnecessary functions

* Update version for breaking change

* Change unnecessary `node_mut()` to `node()` in `join_subtrees()`

* Revert "Update version for breaking change"

This reverts commit db3460510b6861805b5e4a2f3b5eb24d704b10f4.
  • Loading branch information
bvrooman authored Jan 13, 2022
1 parent 4b85dbe commit bba1a28
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 298 deletions.
172 changes: 76 additions & 96 deletions fuel-merkle/src/binary/merkle_tree.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use fuel_storage::Storage;

use crate::binary::{empty_sum, Node};
use crate::common::{Bytes32, Subtree};
use crate::common::{Bytes32, Position, Subtree};

#[derive(Debug, thiserror::Error)]
pub enum MerkleTreeError {
Expand All @@ -12,42 +12,32 @@ pub enum MerkleTreeError {
type ProofSet = Vec<Bytes32>;

pub struct MerkleTree<'storage, StorageError> {
storage: &'storage mut dyn Storage<Bytes32, Node, Error = StorageError>,
storage: &'storage mut dyn Storage<u64, Node, Error = StorageError>,
head: Option<Box<Subtree<Node>>>,
leaves: Vec<Bytes32>,
leaves_count: u64,
}

impl<'storage, StorageError> MerkleTree<'storage, StorageError>
where
StorageError: std::error::Error + 'static,
{
pub fn new(storage: &'storage mut dyn Storage<Bytes32, Node, Error = StorageError>) -> Self {
pub fn new(storage: &'storage mut dyn Storage<u64, Node, Error = StorageError>) -> Self {
Self {
storage,
head: None,
leaves: Vec::<Bytes32>::default(),
leaves_count: 0,
}
}

pub fn root(&mut self) -> Result<Bytes32, Box<dyn std::error::Error>> {
let root = match self.head {
let root_node = self.root_node()?;
let root = match root_node {
None => *empty_sum(),
Some(ref initial) => {
let mut current = initial.clone();
while current.next().is_some() {
let mut head = current;
let mut head_next = head.take_next().unwrap();
current = self.join_subtrees(&mut head_next, &mut head)?
}
current.node().key()
}
Some(ref node) => *node.hash(),
};

Ok(root)
}

pub fn prove(
&mut self,
proof_index: u64,
Expand All @@ -56,25 +46,33 @@ where
return Err(Box::new(MerkleTreeError::InvalidProofIndex(proof_index)));
}

let root = self.root()?;
let mut proof_set = ProofSet::new();

let key = self.leaves[proof_index as usize];
proof_set.push(key);

let mut node = self.storage.get(&key)?.unwrap();
let iter = node.to_mut().proof_iter(self.storage);
for n in iter {
proof_set.push(n.key());
let root_node = self.root_node()?.unwrap();
let root_position = root_node.position();
let leaf_position = Position::from_leaf_index(proof_index);
let leaf_node = self.storage.get(&leaf_position.in_order_index())?.unwrap();
proof_set.push(*leaf_node.hash());

let (_, mut side_positions): (Vec<_>, Vec<_>) = root_position
.path(&leaf_position, self.leaves_count)
.iter()
.unzip();
side_positions.reverse(); // Reorder side positions from leaf to root.
side_positions.pop(); // The last side position is the root; remove it.

for side_position in side_positions {
let key = side_position.in_order_index();
let node = self.storage.get(&key)?.unwrap();
proof_set.push(*node.hash());
}

Ok((root, proof_set))
Ok((self.root()?, proof_set))
}

pub fn push(&mut self, data: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
let node = Node::create_leaf(self.leaves_count, data);
self.storage.insert(&node.key(), &node)?;
self.leaves.push(node.key());

let next = self.head.take();
let head = Box::new(Subtree::<Node>::new(node, next));
Expand All @@ -90,6 +88,23 @@ where
// PRIVATE
//

fn root_node(&mut self) -> Result<Option<Node>, Box<dyn std::error::Error>> {
let root_node = match self.head {
None => None,
Some(ref initial) => {
let mut current = initial.clone();
while current.next().is_some() {
let mut head = current;
let mut head_next = head.take_next().unwrap();
current = self.join_subtrees(&mut head_next, &mut head)?
}
Some(current.node().clone())
}
};

Ok(root_node)
}

fn join_all_subtrees(&mut self) -> Result<(), Box<dyn std::error::Error>> {
loop {
let current = self.head.as_ref().unwrap();
Expand Down Expand Up @@ -117,11 +132,8 @@ where
lhs: &mut Subtree<Node>,
rhs: &mut Subtree<Node>,
) -> Result<Box<Subtree<Node>>, Box<dyn std::error::Error>> {
let joined_node = Node::create_node(lhs.node_mut(), rhs.node_mut());
let joined_node = Node::create_node(lhs.node(), rhs.node());
self.storage.insert(&joined_node.key(), &joined_node)?;
self.storage.insert(&lhs.node().key(), lhs.node())?;
self.storage.insert(&rhs.node().key(), rhs.node())?;

let joined_head = Subtree::new(joined_node, lhs.take_next());
Ok(Box::new(joined_head))
}
Expand All @@ -131,14 +143,14 @@ where
mod test {
use super::{MerkleTree, Storage};
use crate::binary::{empty_sum, leaf_sum, node_sum, Node};
use crate::common::{Bytes32, StorageError, StorageMap};
use crate::common::{StorageError, StorageMap};
use fuel_merkle_test_helpers::TEST_DATA;

type MT<'a> = MerkleTree<'a, StorageError>;

#[test]
fn test_push_builds_internal_tree_structure() {
let mut storage_map = StorageMap::<Bytes32, Node>::new();
let mut storage_map = StorageMap::<u64, Node>::new();
let mut tree = MT::new(&mut storage_map);

let data = &TEST_DATA[0..7]; // 7 leaves
Expand Down Expand Up @@ -168,71 +180,39 @@ mod test {
let leaf_4 = leaf_sum(&data[4]);
let leaf_5 = leaf_sum(&data[5]);
let leaf_6 = leaf_sum(&data[6]);

let node_1 = node_sum(&leaf_0, &leaf_1);
let node_5 = node_sum(&leaf_2, &leaf_3);
let node_3 = node_sum(&node_1, &node_5);
let node_9 = node_sum(&leaf_4, &leaf_5);

let s_leaf_0 = storage_map.get(&leaf_0).unwrap().unwrap();
assert_eq!(s_leaf_0.left_key(), None);
assert_eq!(s_leaf_0.right_key(), None);
assert_eq!(s_leaf_0.parent_key(), Some(node_1.clone()));

let s_leaf_1 = storage_map.get(&leaf_1).unwrap().unwrap();
assert_eq!(s_leaf_1.left_key(), None);
assert_eq!(s_leaf_1.right_key(), None);
assert_eq!(s_leaf_1.parent_key(), Some(node_1.clone()));

let s_leaf_2 = storage_map.get(&leaf_2).unwrap().unwrap();
assert_eq!(s_leaf_2.left_key(), None);
assert_eq!(s_leaf_2.right_key(), None);
assert_eq!(s_leaf_2.parent_key(), Some(node_5.clone()));

let s_leaf_3 = storage_map.get(&leaf_3).unwrap().unwrap();
assert_eq!(s_leaf_3.left_key(), None);
assert_eq!(s_leaf_3.right_key(), None);
assert_eq!(s_leaf_3.parent_key(), Some(node_5.clone()));

let s_leaf_4 = storage_map.get(&leaf_4).unwrap().unwrap();
assert_eq!(s_leaf_4.left_key(), None);
assert_eq!(s_leaf_4.right_key(), None);
assert_eq!(s_leaf_4.parent_key(), Some(node_9.clone()));

let s_leaf_5 = storage_map.get(&leaf_5).unwrap().unwrap();
assert_eq!(s_leaf_5.left_key(), None);
assert_eq!(s_leaf_5.right_key(), None);
assert_eq!(s_leaf_5.parent_key(), Some(node_9.clone()));

let s_leaf_6 = storage_map.get(&leaf_6).unwrap().unwrap();
assert_eq!(s_leaf_6.left_key(), None);
assert_eq!(s_leaf_6.right_key(), None);
assert_eq!(s_leaf_6.parent_key(), None);

let s_node_1 = storage_map.get(&node_1).unwrap().unwrap();
assert_eq!(s_node_1.left_key(), Some(leaf_0.clone()));
assert_eq!(s_node_1.right_key(), Some(leaf_1.clone()));
assert_eq!(s_node_1.parent_key(), Some(node_3.clone()));

let s_node_5 = storage_map.get(&node_5).unwrap().unwrap();
assert_eq!(s_node_5.left_key(), Some(leaf_2.clone()));
assert_eq!(s_node_5.right_key(), Some(leaf_3.clone()));
assert_eq!(s_node_5.parent_key(), Some(node_3.clone()));

let s_node_9 = storage_map.get(&node_9).unwrap().unwrap();
assert_eq!(s_node_9.left_key(), Some(leaf_4.clone()));
assert_eq!(s_node_9.right_key(), Some(leaf_5.clone()));
assert_eq!(s_node_9.parent_key(), None);

let s_node_3 = storage_map.get(&node_3).unwrap().unwrap();
assert_eq!(s_node_3.left_key(), Some(node_1.clone()));
assert_eq!(s_node_3.right_key(), Some(node_5.clone()));
assert_eq!(s_node_3.parent_key(), None);
let s_leaf_0 = storage_map.get(&0).unwrap().unwrap();
let s_leaf_1 = storage_map.get(&2).unwrap().unwrap();
let s_leaf_2 = storage_map.get(&4).unwrap().unwrap();
let s_leaf_3 = storage_map.get(&6).unwrap().unwrap();
let s_leaf_4 = storage_map.get(&8).unwrap().unwrap();
let s_leaf_5 = storage_map.get(&10).unwrap().unwrap();
let s_leaf_6 = storage_map.get(&12).unwrap().unwrap();
let s_node_1 = storage_map.get(&1).unwrap().unwrap();
let s_node_5 = storage_map.get(&5).unwrap().unwrap();
let s_node_9 = storage_map.get(&9).unwrap().unwrap();
let s_node_3 = storage_map.get(&3).unwrap().unwrap();

assert_eq!(s_leaf_0.hash(), &leaf_0);
assert_eq!(s_leaf_1.hash(), &leaf_1);
assert_eq!(s_leaf_2.hash(), &leaf_2);
assert_eq!(s_leaf_3.hash(), &leaf_3);
assert_eq!(s_leaf_4.hash(), &leaf_4);
assert_eq!(s_leaf_5.hash(), &leaf_5);
assert_eq!(s_leaf_6.hash(), &leaf_6);
assert_eq!(s_node_1.hash(), &node_1);
assert_eq!(s_node_5.hash(), &node_5);
assert_eq!(s_node_9.hash(), &node_9);
assert_eq!(s_node_3.hash(), &node_3);
}

#[test]
fn root_returns_the_empty_root_for_0_leaves() {
let mut storage_map = StorageMap::<Bytes32, Node>::new();
let mut storage_map = StorageMap::<u64, Node>::new();
let mut tree = MT::new(&mut storage_map);

let root = tree.root().unwrap();
Expand All @@ -241,7 +221,7 @@ mod test {

#[test]
fn root_returns_the_merkle_root_for_1_leaf() {
let mut storage_map = StorageMap::<Bytes32, Node>::new();
let mut storage_map = StorageMap::<u64, Node>::new();
let mut tree = MT::new(&mut storage_map);

let data = &TEST_DATA[0..1]; // 1 leaf
Expand All @@ -257,7 +237,7 @@ mod test {

#[test]
fn root_returns_the_merkle_root_for_7_leaves() {
let mut storage_map = StorageMap::<Bytes32, Node>::new();
let mut storage_map = StorageMap::<u64, Node>::new();
let mut tree = MT::new(&mut storage_map);

let data = &TEST_DATA[0..7]; // 7 leaves
Expand Down Expand Up @@ -301,7 +281,7 @@ mod test {

#[test]
fn prove_returns_invalid_proof_index_error_for_0_leaves() {
let mut storage_map = StorageMap::<Bytes32, Node>::new();
let mut storage_map = StorageMap::<u64, Node>::new();
let mut tree = MT::new(&mut storage_map);

let proof = tree.prove(0);
Expand All @@ -310,7 +290,7 @@ mod test {

#[test]
fn prove_returns_invalid_proof_index_error_when_index_is_greater_than_number_of_leaves() {
let mut storage_map = StorageMap::<Bytes32, Node>::new();
let mut storage_map = StorageMap::<u64, Node>::new();
let mut tree = MT::new(&mut storage_map);

let data = &TEST_DATA[0..5]; // 5 leaves
Expand All @@ -324,7 +304,7 @@ mod test {

#[test]
fn prove_returns_the_merkle_root_and_proof_set_for_1_leaf() {
let mut storage_map = StorageMap::<Bytes32, Node>::new();
let mut storage_map = StorageMap::<u64, Node>::new();
let mut tree = MT::new(&mut storage_map);

let data = &TEST_DATA[0..1]; // 1 leaf
Expand All @@ -346,7 +326,7 @@ mod test {

#[test]
fn prove_returns_the_merkle_root_and_proof_set_for_4_leaves() {
let mut storage_map = StorageMap::<Bytes32, Node>::new();
let mut storage_map = StorageMap::<u64, Node>::new();
let mut tree = MT::new(&mut storage_map);

let data = &TEST_DATA[0..4]; // 4 leaves
Expand Down Expand Up @@ -415,7 +395,7 @@ mod test {

#[test]
fn prove_returns_the_merkle_root_and_proof_set_for_5_leaves() {
let mut storage_map = StorageMap::<Bytes32, Node>::new();
let mut storage_map = StorageMap::<u64, Node>::new();
let mut tree = MT::new(&mut storage_map);

let data = &TEST_DATA[0..5]; // 5 leaves
Expand Down Expand Up @@ -502,7 +482,7 @@ mod test {

#[test]
fn prove_returns_the_merkle_root_and_proof_set_for_7_leaves() {
let mut storage_map = StorageMap::<Bytes32, Node>::new();
let mut storage_map = StorageMap::<u64, Node>::new();
let mut tree = MT::new(&mut storage_map);

let data = &TEST_DATA[0..7]; // 7 leaves
Expand Down
Loading

0 comments on commit bba1a28

Please sign in to comment.