Skip to content

Commit

Permalink
Add function to update a leaf in a MerkleTree structure (#5453)
Browse files Browse the repository at this point in the history
Co-authored-by: Arr00 <[email protected]>
  • Loading branch information
Amxx and arr00 authored Feb 28, 2025
1 parent 7276774 commit 71bc0f7
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 28 deletions.
5 changes: 5 additions & 0 deletions .changeset/good-zebras-ring.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`MerkleTree`: Add an update function that replaces a previously inserted leaf with a new value, updating the tree root along the way.
8 changes: 8 additions & 0 deletions contracts/mocks/MerkleTreeMock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ contract MerkleTreeMock {
bytes32 public root;

event LeafInserted(bytes32 leaf, uint256 index, bytes32 root);
event LeafUpdated(bytes32 oldLeaf, bytes32 newLeaf, uint256 index, bytes32 root);

function setup(uint8 _depth, bytes32 _zero) public {
root = _tree.setup(_depth, _zero);
Expand All @@ -25,6 +26,13 @@ contract MerkleTreeMock {
root = currentRoot;
}

function update(uint256 index, bytes32 oldValue, bytes32 newValue, bytes32[] memory proof) public {
(bytes32 oldRoot, bytes32 newRoot) = _tree.update(index, oldValue, newValue, proof);
if (oldRoot != root) revert MerkleTree.MerkleTreeUpdateInvalidProof();
emit LeafUpdated(oldValue, newValue, index, newRoot);
root = newRoot;
}

function depth() public view returns (uint256) {
return _tree.depth();
}
Expand Down
92 changes: 92 additions & 0 deletions contracts/utils/structs/MerkleTree.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pragma solidity ^0.8.20;
import {Hashes} from "../cryptography/Hashes.sol";
import {Arrays} from "../Arrays.sol";
import {Panic} from "../Panic.sol";
import {StorageSlot} from "../StorageSlot.sol";

/**
* @dev Library for managing https://wikipedia.org/wiki/Merkle_Tree[Merkle Tree] data structures.
Expand All @@ -27,6 +28,12 @@ import {Panic} from "../Panic.sol";
* _Available since v5.1._
*/
library MerkleTree {
/// @dev Error emitted when trying to update a leaf that was not previously pushed.
error MerkleTreeUpdateInvalidIndex(uint256 index, uint256 length);

/// @dev Error emitted when the proof used during an update is invalid (could not reproduce the side).
error MerkleTreeUpdateInvalidProof();

/**
* @dev A complete `bytes32` Merkle tree.
*
Expand Down Expand Up @@ -166,6 +173,91 @@ library MerkleTree {
return (index, currentLevelHash);
}

/**
* @dev Change the value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old"
* root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old
* root is the last known one.
*
* The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is
* vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render
* all "in flight" updates invalid.
*
* This variant uses {Hashes-commutativeKeccak256} to hash internal nodes. It should only be used on merkle trees
* that were setup using the same (default) hashing function (i.e. by calling
* {xref-MerkleTree-setup-struct-MerkleTree-Bytes32PushTree-uint8-bytes32-}[the default setup] function).
*/
function update(
Bytes32PushTree storage self,
uint256 index,
bytes32 oldValue,
bytes32 newValue,
bytes32[] memory proof
) internal returns (bytes32 oldRoot, bytes32 newRoot) {
return update(self, index, oldValue, newValue, proof, Hashes.commutativeKeccak256);
}

/**
* @dev Change the value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old"
* root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old
* root is the last known one.
*
* The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is
* vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render
* all "in flight" updates invalid.
*
* This variant uses a custom hashing function to hash internal nodes. It should only be called with the same
* function as the one used during the initial setup of the merkle tree.
*/
function update(
Bytes32PushTree storage self,
uint256 index,
bytes32 oldValue,
bytes32 newValue,
bytes32[] memory proof,
function(bytes32, bytes32) view returns (bytes32) fnHash
) internal returns (bytes32 oldRoot, bytes32 newRoot) {
unchecked {
// Check index range
uint256 length = self._nextLeafIndex;
if (index >= length) revert MerkleTreeUpdateInvalidIndex(index, length);

// Cache read
uint256 treeDepth = depth(self);

// Workaround stack too deep
bytes32[] storage sides = self._sides;

// This cannot overflow because: 0 <= index < length
uint256 lastIndex = length - 1;
uint256 currentIndex = index;
bytes32 currentLevelHashOld = oldValue;
bytes32 currentLevelHashNew = newValue;
for (uint32 i = 0; i < treeDepth; i++) {
bool isLeft = currentIndex % 2 == 0;

lastIndex >>= 1;
currentIndex >>= 1;

if (isLeft && currentIndex == lastIndex) {
StorageSlot.Bytes32Slot storage side = Arrays.unsafeAccess(sides, i);
if (side.value != currentLevelHashOld) revert MerkleTreeUpdateInvalidProof();
side.value = currentLevelHashNew;
}

bytes32 sibling = proof[i];
currentLevelHashOld = fnHash(
isLeft ? currentLevelHashOld : sibling,
isLeft ? sibling : currentLevelHashOld
);
currentLevelHashNew = fnHash(
isLeft ? currentLevelHashNew : sibling,
isLeft ? sibling : currentLevelHashNew
);
}
return (currentLevelHashOld, currentLevelHashNew);
}
}

/**
* @dev Tree's depth (set at initialization)
*/
Expand Down
136 changes: 108 additions & 28 deletions test/utils/structs/MerkleTree.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
const { StandardMerkleTree } = require('@openzeppelin/merkle-tree');

const { generators } = require('../../helpers/random');
const { range } = require('../../helpers/iterate');

const makeTree = (leaves = [ethers.ZeroHash]) =>
const DEPTH = 4; // 16 slots

const makeTree = (leaves = [], length = 2 ** DEPTH, zero = ethers.ZeroHash) =>
StandardMerkleTree.of(
leaves.map(leaf => [leaf]),
[]
.concat(
leaves,
Array.from({ length: length - leaves.length }, () => zero),
)
.map(leaf => [leaf]),
['bytes32'],
{ sortLeaves: false },
);

const hashLeaf = leaf => makeTree().leafHash([leaf]);

const DEPTH = 4n; // 16 slots
const ZERO = hashLeaf(ethers.ZeroHash);
const ZERO = makeTree().leafHash([ethers.ZeroHash]);

async function fixture() {
const mock = await ethers.deployContract('MerkleTreeMock');
Expand All @@ -30,69 +35,144 @@ describe('MerkleTree', function () {
});

it('sets initial values at setup', async function () {
const merkleTree = makeTree(Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash));
const merkleTree = makeTree();

expect(await this.mock.root()).to.equal(merkleTree.root);
expect(await this.mock.depth()).to.equal(DEPTH);
expect(await this.mock.nextLeafIndex()).to.equal(0n);
await expect(this.mock.root()).to.eventually.equal(merkleTree.root);
await expect(this.mock.depth()).to.eventually.equal(DEPTH);
await expect(this.mock.nextLeafIndex()).to.eventually.equal(0n);
});

describe('push', function () {
it('tree is correctly updated', async function () {
const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
it('pushing correctly updates the tree', async function () {
const leaves = [];

// for each leaf slot
for (const i in leaves) {
// generate random leaf and hash it
const hashedLeaf = hashLeaf((leaves[i] = generators.bytes32()));
for (const i in range(2 ** DEPTH)) {
// generate random leaf
leaves.push(generators.bytes32());

// update leaf list and rebuild tree.
// rebuild tree.
const tree = makeTree(leaves);
const hash = tree.leafHash(tree.at(i));

// push value to tree
await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, i, tree.root);
await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, i, tree.root);

// check tree
expect(await this.mock.root()).to.equal(tree.root);
expect(await this.mock.nextLeafIndex()).to.equal(BigInt(i) + 1n);
await expect(this.mock.root()).to.eventually.equal(tree.root);
await expect(this.mock.nextLeafIndex()).to.eventually.equal(BigInt(i) + 1n);
}
});

it('revert when tree is full', async function () {
it('pushing to a full tree reverts', async function () {
await Promise.all(Array.from({ length: 2 ** Number(DEPTH) }).map(() => this.mock.push(ethers.ZeroHash)));

await expect(this.mock.push(ethers.ZeroHash)).to.be.revertedWithPanic(PANIC_CODES.TOO_MUCH_MEMORY_ALLOCATED);
});
});

describe('update', function () {
for (const { leafCount, leafIndex } of range(2 ** DEPTH + 1).flatMap(leafCount =>
range(leafCount).map(leafIndex => ({ leafCount, leafIndex })),
))
it(`updating a leaf correctly updates the tree (leaf #${leafIndex + 1}/${leafCount})`, async function () {
// initial tree
const leaves = Array.from({ length: leafCount }, generators.bytes32);
const oldTree = makeTree(leaves);

// fill tree and verify root
for (const i in leaves) {
await this.mock.push(oldTree.leafHash(oldTree.at(i)));
}
await expect(this.mock.root()).to.eventually.equal(oldTree.root);

// create updated tree
leaves[leafIndex] = generators.bytes32();
const newTree = makeTree(leaves);

const oldLeafHash = oldTree.leafHash(oldTree.at(leafIndex));
const newLeafHash = newTree.leafHash(newTree.at(leafIndex));

// perform update
await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, oldTree.getProof(leafIndex)))
.to.emit(this.mock, 'LeafUpdated')
.withArgs(oldLeafHash, newLeafHash, leafIndex, newTree.root);

// verify updated root
await expect(this.mock.root()).to.eventually.equal(newTree.root);

// if there is still room in the tree, fill it
for (const i of range(leafCount, 2 ** DEPTH)) {
// push new value and rebuild tree
leaves.push(generators.bytes32());
const nextTree = makeTree(leaves);

// push and verify root
await this.mock.push(nextTree.leafHash(nextTree.at(i)));
await expect(this.mock.root()).to.eventually.equal(nextTree.root);
}
});

it('replacing a leaf that was not previously pushed reverts', async function () {
// changing leaf 0 on an empty tree
await expect(this.mock.update(1, ZERO, ZERO, []))
.to.be.revertedWithCustomError(this.mock, 'MerkleTreeUpdateInvalidIndex')
.withArgs(1, 0);
});

it('replacing a leaf using an invalid proof reverts', async function () {
const leafCount = 4;
const leafIndex = 2;

const leaves = Array.from({ length: leafCount }, generators.bytes32);
const tree = makeTree(leaves);

// fill tree and verify root
for (const i in leaves) {
await this.mock.push(tree.leafHash(tree.at(i)));
}
await expect(this.mock.root()).to.eventually.equal(tree.root);

const oldLeafHash = tree.leafHash(tree.at(leafIndex));
const newLeafHash = generators.bytes32();
const proof = tree.getProof(leafIndex);
// invalid proof (tamper)
proof[1] = generators.bytes32();

await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, proof)).to.be.revertedWithCustomError(
this.mock,
'MerkleTreeUpdateInvalidProof',
);
});
});

it('reset', async function () {
// empty tree
const zeroLeaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
const zeroTree = makeTree(zeroLeaves);
const emptyTree = makeTree();

// tree with one element
const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
const hashedLeaf = hashLeaf((leaves[0] = generators.bytes32())); // fill first leaf and hash it
const leaves = [generators.bytes32()];
const tree = makeTree(leaves);
const hash = tree.leafHash(tree.at(0));

// root should be that of a zero tree
expect(await this.mock.root()).to.equal(zeroTree.root);
expect(await this.mock.root()).to.equal(emptyTree.root);
expect(await this.mock.nextLeafIndex()).to.equal(0n);

// push leaf and check root
await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root);

expect(await this.mock.root()).to.equal(tree.root);
expect(await this.mock.nextLeafIndex()).to.equal(1n);

// reset tree
await this.mock.setup(DEPTH, ZERO);

expect(await this.mock.root()).to.equal(zeroTree.root);
expect(await this.mock.root()).to.equal(emptyTree.root);
expect(await this.mock.nextLeafIndex()).to.equal(0n);

// re-push leaf and check root
await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root);

expect(await this.mock.root()).to.equal(tree.root);
expect(await this.mock.nextLeafIndex()).to.equal(1n);
Expand Down

0 comments on commit 71bc0f7

Please sign in to comment.