forked from celestiaorg/smt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
smt.go
457 lines (398 loc) · 14 KB
/
smt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
// Package smt implements a Sparse Merkle tree.
package smt
import (
"bytes"
"errors"
"hash"
)
// right is the path-bit value that selects the right child when walking the
// tree from the root towards a leaf.
const right = 1

var (
	// defaultValue is returned for keys that are not present in the tree;
	// writing it for a key deletes that key.
	defaultValue = []byte{}

	// errKeyAlreadyEmpty reports that a deletion targeted a key with no leaf.
	errKeyAlreadyEmpty = errors.New("key already empty")
)
// SparseMerkleTree is a Sparse Merkle tree.
//
// Two stores back the tree: nodes maps a node hash to that node's encoded
// data, and values maps a key's path to the raw value stored for it. root is
// the hash of the current root node.
type SparseMerkleTree struct {
	th     treeHasher
	nodes  MapStore
	values MapStore
	root   []byte
}
// NewSparseMerkleTree creates a new Sparse Merkle tree on an empty MapStore.
//
// The supplied options are applied in order to the new tree. Afterwards the
// root is always reset to the hasher's placeholder, i.e. an empty tree.
func NewSparseMerkleTree(nodes, values MapStore, hasher hash.Hash, options ...Option) *SparseMerkleTree {
	tree := &SparseMerkleTree{
		th:     *newTreeHasher(hasher),
		nodes:  nodes,
		values: values,
	}
	for _, apply := range options {
		apply(tree)
	}
	tree.SetRoot(tree.th.placeholder())
	return tree
}
// ImportSparseMerkleTree imports a Sparse Merkle tree from a non-empty MapStore.
//
// root is used as the tree's current root as-is; it is not validated against
// the contents of nodes.
func ImportSparseMerkleTree(nodes, values MapStore, hasher hash.Hash, root []byte) *SparseMerkleTree {
	return &SparseMerkleTree{
		th:     *newTreeHasher(hasher),
		nodes:  nodes,
		values: values,
		root:   root,
	}
}
// Root returns the hash of the tree's current root node.
func (smt *SparseMerkleTree) Root() []byte { return smt.root }

// SetRoot replaces the tree's current root with root. The slice is stored
// without copying.
func (smt *SparseMerkleTree) SetRoot(root []byte) { smt.root = root }

// depth returns the number of levels in the tree: pathSize() * 8, i.e. one
// level per path bit (assuming pathSize is measured in bytes).
func (smt *SparseMerkleTree) depth() int { return 8 * smt.th.pathSize() }
// Get gets the value of a key from the tree.
//
// An empty tree (placeholder root), or a key the values store reports as an
// *InvalidKeyError, yields defaultValue with a nil error. Any other store
// error is returned unchanged.
func (smt *SparseMerkleTree) Get(key []byte) ([]byte, error) {
	if bytes.Equal(smt.Root(), smt.th.placeholder()) {
		// Nothing has been stored yet, so every key maps to the default.
		return defaultValue, nil
	}

	value, err := smt.values.Get(smt.th.path(key))
	if err == nil {
		return value, nil
	}

	var notFound *InvalidKeyError
	if errors.As(err, &notFound) {
		// A missing key is not an error for callers: report the default.
		return defaultValue, nil
	}
	return nil, err
}
// Has reports whether key holds a non-default value. The error from Get is
// passed through; when it is non-nil the boolean is false.
func (smt *SparseMerkleTree) Has(key []byte) (bool, error) {
	val, err := smt.Get(key)
	present := !bytes.Equal(defaultValue, val)
	return present, err
}
// Update sets a new value for a key in the tree, and sets and returns the new root of the tree.
//
// The tree's root is only changed when the update succeeds.
func (smt *SparseMerkleTree) Update(key []byte, value []byte) ([]byte, error) {
	updated, err := smt.UpdateForRoot(key, value, smt.Root())
	if err != nil {
		return nil, err
	}
	smt.SetRoot(updated)
	return updated, nil
}
// Delete deletes a value from tree. It returns the new root of the tree.
//
// Deletion is an Update to defaultValue, so the tree's root is advanced on
// success.
func (smt *SparseMerkleTree) Delete(key []byte) ([]byte, error) {
	return smt.Update(key, defaultValue)
}
// UpdateForRoot sets a new value for a key in the tree at a specific root, and
// returns the new root.
//
// The tree's own current root is not changed (use Update for that), but the
// nodes and values stores are modified. Writing defaultValue deletes the key;
// deleting a key that is already absent returns root unchanged.
func (smt *SparseMerkleTree) UpdateForRoot(key []byte, value []byte, root []byte) ([]byte, error) {
	path := smt.th.path(key)
	sideNodes, pathNodes, oldLeafData, _, err := smt.sideNodesForRoot(path, root, false)
	if err != nil {
		return nil, err
	}

	if !bytes.Equal(value, defaultValue) {
		// Insert or update operation.
		return smt.updateWithSideNodes(path, value, sideNodes, pathNodes, oldLeafData)
	}

	// Delete operation.
	newRoot, err := smt.deleteWithSideNodes(path, sideNodes, pathNodes, oldLeafData)
	if errors.Is(err, errKeyAlreadyEmpty) {
		// This key is already empty; return the old root.
		return root, nil
	}
	if err != nil {
		// The tree could not be rewritten; keep the stored value so the
		// values store stays consistent with the (unchanged) tree.
		return nil, err
	}
	if err = smt.values.Delete(path); err != nil {
		return nil, err
	}
	return newRoot, nil
}
// DeleteForRoot deletes a value from tree at a specific root. It returns the new root of the tree.
//
// It is UpdateForRoot with defaultValue, so the tree's current root is left
// untouched.
func (smt *SparseMerkleTree) DeleteForRoot(key, root []byte) ([]byte, error) {
	return smt.UpdateForRoot(key, defaultValue, root)
}
// deleteWithSideNodes removes the leaf stored at path from the tree described
// by sideNodes and pathNodes (as produced by sideNodesForRoot) and returns the
// new root hash.
//
// It returns errKeyAlreadyEmpty when pathNodes[0] is a placeholder, or when it
// is a leaf whose path differs from path. Otherwise every hash in pathNodes is
// deleted from the nodes store and the tree is rebuilt bottom-up:
//   - if the deleted leaf's immediate sibling is a leaf, that sibling is moved
//     up past any placeholder siblings until the first non-placeholder one;
//   - if the sibling is an internal node, it stays where it is and the deleted
//     leaf's slot becomes a placeholder.
//
// Newly computed nodes are written to the nodes store. If no side nodes exist
// (the deleted leaf was the root), the placeholder is returned as the root.
// The values store is not modified here.
func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte, pathNodes [][]byte, oldLeafData []byte) ([]byte, error) {
	if bytes.Equal(pathNodes[0], smt.th.placeholder()) {
		// This key is already empty as it is a placeholder; return an error.
		return nil, errKeyAlreadyEmpty
	}
	actualPath, _ := smt.th.parseLeaf(oldLeafData)
	if !bytes.Equal(path, actualPath) {
		// This key is already empty as a different key was found in its
		// place; return an error.
		return nil, errKeyAlreadyEmpty
	}
	// The deleted leaf (pathNodes[0]) and every node above it on the path are
	// now orphaned.
	for _, node := range pathNodes {
		if err := smt.nodes.Delete(node); err != nil {
			return nil, err
		}
	}
	var currentHash, currentData []byte
	nonPlaceholderReached := false
	// sideNodes is ordered bottom-up, so sideNodes[i] sits at depth
	// len(sideNodes)-1-i.
	for i, sideNode := range sideNodes {
		// currentData is nil only on the first iteration: classify the
		// deleted leaf's immediate sibling.
		if currentData == nil {
			sideNodeValue, err := smt.nodes.Get(sideNode)
			if err != nil {
				return nil, err
			}
			if smt.th.isLeaf(sideNodeValue) {
				// This is the leaf sibling that needs to be bubbled up the tree.
				currentHash = sideNode
				currentData = sideNode
				continue
			} else {
				// This is the node sibling that needs to be left in its place.
				currentData = smt.th.placeholder()
				nonPlaceholderReached = true
			}
		}
		if !nonPlaceholderReached && bytes.Equal(sideNode, smt.th.placeholder()) {
			// We found another placeholder sibling node, keep going up the
			// tree until we find the first sibling that is not a placeholder.
			continue
		} else if !nonPlaceholderReached {
			// We found the first sibling node that is not a placeholder, it is
			// time to insert our leaf sibling node here.
			nonPlaceholderReached = true
		}
		if getBitAtFromMSB(path, len(sideNodes)-1-i) == right {
			currentHash, currentData = smt.th.digestNode(sideNode, currentData)
		} else {
			currentHash, currentData = smt.th.digestNode(currentData, sideNode)
		}
		if err := smt.nodes.Set(currentHash, currentData); err != nil {
			return nil, err
		}
		// The parent is referenced by hash one level further up.
		currentData = currentHash
	}
	if currentHash == nil {
		// The tree is empty; return placeholder value as root.
		currentHash = smt.th.placeholder()
	}
	return currentHash, nil
}
// updateWithSideNodes writes a leaf for (path, value) into the tree described
// by sideNodes and pathNodes (as produced by sideNodesForRoot) and returns the
// new root hash.
//
// The behavior depends on pathNodes[0], the node found at the bottom of path:
//   - a placeholder: the new leaf takes its place;
//   - a leaf for a different path: an intermediate node holding both leaves is
//     created at the first bit where the two paths differ, with placeholder
//     siblings filling the levels up to the existing side nodes;
//   - a leaf for the same path: the old leaf and stored value are replaced,
//     unless the value hash is unchanged, in which case the root the walk
//     started from is returned and nothing else is modified.
//
// The new leaf is always written to the nodes store, path nodes above
// pathNodes[0] are deleted from it, new intermediate nodes are written, and
// value is stored under path in the values store.
func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, sideNodes [][]byte, pathNodes [][]byte, oldLeafData []byte) ([]byte, error) {
	valueHash := smt.th.digest(value)
	currentHash, currentData := smt.th.digestLeaf(path, valueHash)
	if err := smt.nodes.Set(currentHash, currentData); err != nil {
		return nil, err
	}
	currentData = currentHash

	// If the leaf node that sibling nodes lead to has a different actual path
	// than the leaf node being updated, we need to create an intermediate node
	// with this leaf node and the new leaf node as children.
	//
	// First, get the number of bits that the paths of the two leaf nodes share
	// in common as a prefix.
	var commonPrefixCount int
	var oldValueHash []byte
	if bytes.Equal(pathNodes[0], smt.th.placeholder()) {
		commonPrefixCount = smt.depth()
	} else {
		var actualPath []byte
		actualPath, oldValueHash = smt.th.parseLeaf(oldLeafData)
		commonPrefixCount = countCommonPrefix(path, actualPath)
	}

	if commonPrefixCount != smt.depth() {
		// Different leaf: join the old leaf and the new one under a new node.
		if getBitAtFromMSB(path, commonPrefixCount) == right {
			currentHash, currentData = smt.th.digestNode(pathNodes[0], currentData)
		} else {
			currentHash, currentData = smt.th.digestNode(currentData, pathNodes[0])
		}
		if err := smt.nodes.Set(currentHash, currentData); err != nil {
			return nil, err
		}
		currentData = currentHash
	} else if oldValueHash != nil {
		// Short-circuit if the same value is being set. Nothing changes, so
		// return the root this update was computed against: pathNodes runs
		// from the bottom of the path up to the root, so its last element is
		// that root. (smt.root would be wrong whenever UpdateForRoot is called
		// with a root other than the tree's current one.)
		if bytes.Equal(oldValueHash, valueHash) {
			return pathNodes[len(pathNodes)-1], nil
		}
		// If an old leaf exists, remove it
		if err := smt.nodes.Delete(pathNodes[0]); err != nil {
			return nil, err
		}
		if err := smt.values.Delete(path); err != nil {
			return nil, err
		}
	}

	// All remaining path nodes are orphaned
	for i := 1; i < len(pathNodes); i++ {
		if err := smt.nodes.Delete(pathNodes[i]); err != nil {
			return nil, err
		}
	}

	// The offset from the bottom of the tree to the start of the side nodes.
	// Note: i-offsetOfSideNodes is the index into sideNodes[]
	offsetOfSideNodes := smt.depth() - len(sideNodes)
	for i := 0; i < smt.depth(); i++ {
		var sideNode []byte
		if i-offsetOfSideNodes < 0 || sideNodes[i-offsetOfSideNodes] == nil {
			if commonPrefixCount != smt.depth() && commonPrefixCount > smt.depth()-1-i {
				// If there are no sidenodes at this height, but the number of
				// bits that the paths of the two leaf nodes share in common is
				// greater than this depth, then we need to build up the tree
				// to this depth with placeholder values at siblings.
				sideNode = smt.th.placeholder()
			} else {
				continue
			}
		} else {
			sideNode = sideNodes[i-offsetOfSideNodes]
		}

		if getBitAtFromMSB(path, smt.depth()-1-i) == right {
			currentHash, currentData = smt.th.digestNode(sideNode, currentData)
		} else {
			currentHash, currentData = smt.th.digestNode(currentData, sideNode)
		}
		if err := smt.nodes.Set(currentHash, currentData); err != nil {
			return nil, err
		}
		currentData = currentHash
	}

	if err := smt.values.Set(path, value); err != nil {
		return nil, err
	}
	return currentHash, nil
}
// sideNodesForRoot walks path down from root and collects the nodes it passes.
//
// It returns, in order:
//   - sideNodes: the sibling hash at each level visited, ordered from the
//     bottom of the walk up to the root (empty when root is a placeholder or
//     a leaf);
//   - pathNodes: the hashes on the path itself, ordered from the node where
//     the walk stopped (normally a leaf or a placeholder) up to and including
//     root, so pathNodes[len(pathNodes)-1] is always root;
//   - the data of the leaf where the walk stopped, or nil if it stopped at a
//     placeholder;
//   - when getSiblingData is true and at least one level was walked, the
//     stored data of the bottom-most side node; otherwise nil;
//   - any error from the nodes store.
func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte, getSiblingData bool) ([][]byte, [][]byte, []byte, []byte, error) {
	// Side nodes for the path. Nodes are inserted in reverse order, then the
	// slice is reversed at the end.
	sideNodes := make([][]byte, 0, smt.depth())
	pathNodes := make([][]byte, 0, smt.depth()+1)
	pathNodes = append(pathNodes, root)
	if bytes.Equal(root, smt.th.placeholder()) {
		// If the root is a placeholder, there are no sidenodes to return.
		// Let the "actual path" be the input path.
		return sideNodes, pathNodes, nil, nil, nil
	}
	currentData, err := smt.nodes.Get(root)
	if err != nil {
		return nil, nil, nil, nil, err
	} else if smt.th.isLeaf(currentData) {
		// If the root is a leaf, there are also no sidenodes to return.
		return sideNodes, pathNodes, currentData, nil, nil
	}
	var nodeHash []byte
	var sideNode []byte
	var siblingData []byte
	for i := 0; i < smt.depth(); i++ {
		leftNode, rightNode := smt.th.parseNode(currentData)
		// Get sidenode depending on whether the path bit is on or off.
		if getBitAtFromMSB(path, i) == right {
			sideNode = leftNode
			nodeHash = rightNode
		} else {
			sideNode = rightNode
			nodeHash = leftNode
		}
		sideNodes = append(sideNodes, sideNode)
		pathNodes = append(pathNodes, nodeHash)
		if bytes.Equal(nodeHash, smt.th.placeholder()) {
			// If the node is a placeholder, we've reached the end.
			currentData = nil
			break
		}
		currentData, err = smt.nodes.Get(nodeHash)
		if err != nil {
			return nil, nil, nil, nil, err
		} else if smt.th.isLeaf(currentData) {
			// If the node is a leaf, we've reached the end.
			break
		}
	}
	if getSiblingData {
		// sideNode still holds the sibling from the last level walked.
		siblingData, err = smt.nodes.Get(sideNode)
		if err != nil {
			return nil, nil, nil, nil, err
		}
	}
	return reverseByteSlices(sideNodes), reverseByteSlices(pathNodes), currentData, siblingData, nil
}
// Prove generates a Merkle proof for a key against the current root.
//
// This proof can be used for read-only applications, but should not be used if
// the leaf may be updated (e.g. in a state transition fraud proof). For
// updatable proofs, see ProveUpdatable.
func (smt *SparseMerkleTree) Prove(key []byte) (SparseMerkleProof, error) {
	return smt.ProveForRoot(key, smt.Root())
}
// ProveForRoot generates a Merkle proof for a key, against a specific node.
// This is primarily useful for generating Merkle proofs for subtrees.
//
// This proof can be used for read-only applications, but should not be used if
// the leaf may be updated (e.g. in a state transition fraud proof). For
// updatable proofs, see ProveUpdatableForRoot.
func (smt *SparseMerkleTree) ProveForRoot(key []byte, root []byte) (SparseMerkleProof, error) {
	const updatable = false // no sibling data in read-only proofs
	return smt.doProveForRoot(key, root, updatable)
}
// ProveUpdatable generates an updatable Merkle proof for a key against the current root.
func (smt *SparseMerkleTree) ProveUpdatable(key []byte) (SparseMerkleProof, error) {
	return smt.ProveUpdatableForRoot(key, smt.Root())
}
// ProveUpdatableForRoot generates an updatable Merkle proof for a key, against a specific node.
// This is primarily useful for generating Merkle proofs for subtrees.
//
// Unlike ProveForRoot, the proof also carries the data of the bottom-most
// side node (SiblingData).
func (smt *SparseMerkleTree) ProveUpdatableForRoot(key []byte, root []byte) (SparseMerkleProof, error) {
	const updatable = true // request sibling data from sideNodesForRoot
	return smt.doProveForRoot(key, root, updatable)
}
// doProveForRoot builds a SparseMerkleProof for key against root.
//
// Non-nil side nodes are copied into the proof in the order returned by
// sideNodesForRoot (SideNodes stays nil if there are none). When isUpdatable
// is true, the bottom-most sibling's data is attached as SiblingData. If the
// path ends at a leaf belonging to a different key, that leaf's data is
// attached as NonMembershipLeafData.
func (smt *SparseMerkleTree) doProveForRoot(key []byte, root []byte, isUpdatable bool) (SparseMerkleProof, error) {
	path := smt.th.path(key)
	sideNodes, pathNodes, leafData, siblingData, err := smt.sideNodesForRoot(path, root, isUpdatable)
	if err != nil {
		return SparseMerkleProof{}, err
	}

	proof := SparseMerkleProof{SiblingData: siblingData}
	for _, sideNode := range sideNodes {
		if sideNode == nil {
			continue
		}
		proof.SideNodes = append(proof.SideNodes, sideNode)
	}

	// A path ending in a placeholder needs nothing extra. A path ending in a
	// leaf for some other key is a non-membership proof: include that leaf.
	if !bytes.Equal(pathNodes[0], smt.th.placeholder()) {
		if actualPath, _ := smt.th.parseLeaf(leafData); !bytes.Equal(actualPath, path) {
			proof.NonMembershipLeafData = leafData
		}
	}
	return proof, nil
}
// ProveCompact generates a compacted Merkle proof for a key against the current root.
func (smt *SparseMerkleTree) ProveCompact(key []byte) (SparseCompactMerkleProof, error) {
	return smt.ProveCompactForRoot(key, smt.Root())
}
// ProveCompactForRoot generates a compacted Merkle proof for a key, at a specific root.
//
// It builds a regular (non-updatable) proof with ProveForRoot and compacts it
// with CompactProof using the tree's hasher.
func (smt *SparseMerkleTree) ProveCompactForRoot(key []byte, root []byte) (SparseCompactMerkleProof, error) {
	proof, err := smt.ProveForRoot(key, root)
	if err != nil {
		return SparseCompactMerkleProof{}, err
	}
	return CompactProof(proof, smt.th.hasher)
}