Skip to content

Commit

Permalink
core/state/snapshot: ensure Cap retains a min number of layers
Browse files Browse the repository at this point in the history
  • Loading branch information
karalabe committed Feb 16, 2021
1 parent 7d1b711 commit 9ec3329
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 93 deletions.
47 changes: 16 additions & 31 deletions core/state/snapshot/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,12 @@ func (t *Tree) Update(blockRoot common.Hash, parentRoot common.Hash, destructs m
// Cap traverses downwards the snapshot tree from a head block hash until the
// number of allowed layers are crossed. All layers beyond the permitted number
// are flattened downwards.
//
// Note, the final diff layer count in general will be one more than the amount
// requested. This happens because the bottom-most diff layer is the accumulator
// which may or may not overflow and cascade to disk. Since this last layer's
// survival is only known *after* capping, we need to omit it from the count if
// we want to ensure that *at least* the requested number of diff layers remain.
func (t *Tree) Cap(root common.Hash, layers int) error {
// Retrieve the head snapshot to cap from
snap := t.Snapshot(root)
Expand All @@ -324,10 +330,7 @@ func (t *Tree) Cap(root common.Hash, layers int) error {
// Flattening the bottom-most diff layer requires special casing since there's
// no child to rewire to the grandparent. In that case we can fake a temporary
// child for the capping and then remove it.
var persisted *diskLayer

switch layers {
case 0:
if layers == 0 {
// If full commit was requested, flatten the diffs and merge onto disk
diff.lock.RLock()
base := diffToDisk(diff.flatten().(*diffLayer))
Expand All @@ -336,33 +339,9 @@ func (t *Tree) Cap(root common.Hash, layers int) error {
// Replace the entire snapshot tree with the flat base
t.layers = map[common.Hash]snapshot{base.root: base}
return nil

case 1:
// If full flattening was requested, flatten the diffs but only merge if the
// memory limit was reached
var (
bottom *diffLayer
base *diskLayer
)
diff.lock.RLock()
bottom = diff.flatten().(*diffLayer)
if bottom.memory >= aggregatorMemoryLimit {
base = diffToDisk(bottom)
}
diff.lock.RUnlock()

// If all diff layers were removed, replace the entire snapshot tree
if base != nil {
t.layers = map[common.Hash]snapshot{base.root: base}
return nil
}
// Merge the new aggregated layer into the snapshot tree, clean stales below
t.layers[bottom.root] = bottom

default:
// Many layers requested to be retained, cap normally
persisted = t.cap(diff, layers)
}
persisted := t.cap(diff, layers)

// Remove any layer that is stale or links into a stale layer
children := make(map[common.Hash][]common.Hash)
for root, snap := range t.layers {
Expand Down Expand Up @@ -405,9 +384,15 @@ func (t *Tree) Cap(root common.Hash, layers int) error {
// layer limit is reached, memory cap is also enforced (but not before).
//
// The method returns the new disk layer if diffs were persisted into it.
//
// Note, the final diff layer count in general will be one more than the amount
// requested. This happens because the bottom-most diff layer is the accumulator
// which may or may not overflow and cascade to disk. Since this last layer's
// survival is only known *after* capping, we need to omit it from the count if
// we want to ensure that *at least* the requested number of diff layers remain.
func (t *Tree) cap(diff *diffLayer, layers int) *diskLayer {
// Dive until we run out of layers or reach the persistent database
for ; layers > 2; layers-- {
for i := 0; i < layers-1; i++ {
// If we still have diff layers below, continue down
if parent, ok := diff.parent.(*diffLayer); ok {
diff = parent
Expand Down
111 changes: 49 additions & 62 deletions core/state/snapshot/snapshot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,57 +162,10 @@ func TestDiskLayerExternalInvalidationPartialFlatten(t *testing.T) {
defer func(memcap uint64) { aggregatorMemoryLimit = memcap }(aggregatorMemoryLimit)
aggregatorMemoryLimit = 0

if err := snaps.Cap(common.HexToHash("0x03"), 2); err != nil {
t.Fatalf("failed to merge diff layer onto disk: %v", err)
}
// Since the base layer was modified, ensure that data retrievald on the external reference fail
if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
}
if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale {
t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err)
}
if n := len(snaps.layers); n != 2 {
t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 2)
fmt.Println(snaps.layers)
}
}

// Tests that if a diff layer becomes stale, no active external references will
// be returned with junk data. This version of the test flattens every diff layer
// to check internal corner case around the bottom-most memory accumulator.
func TestDiffLayerExternalInvalidationFullFlatten(t *testing.T) {
// Create an empty base layer and a snapshot tree out of it
base := &diskLayer{
diskdb: rawdb.NewMemoryDatabase(),
root: common.HexToHash("0x01"),
cache: fastcache.New(1024 * 500),
}
snaps := &Tree{
layers: map[common.Hash]snapshot{
base.root: base,
},
}
// Commit two diffs on top and retrieve a reference to the bottommost
accounts := map[common.Hash][]byte{
common.HexToHash("0xa1"): randomAccount(),
}
if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
t.Fatalf("failed to create a diff layer: %v", err)
}
if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
t.Fatalf("failed to create a diff layer: %v", err)
}
if n := len(snaps.layers); n != 3 {
t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 3)
}
ref := snaps.Snapshot(common.HexToHash("0x02"))

// Flatten the diff layer into the bottom accumulator
if err := snaps.Cap(common.HexToHash("0x03"), 1); err != nil {
t.Fatalf("failed to flatten diff layer into accumulator: %v", err)
t.Fatalf("failed to merge accumulator onto disk: %v", err)
}
// Since the accumulator diff layer was modified, ensure that data retrievald on the external reference fail
// Since the base layer was modified, ensure that data retrievald on the external reference fail
if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
}
Expand Down Expand Up @@ -267,7 +220,7 @@ func TestDiffLayerExternalInvalidationPartialFlatten(t *testing.T) {
t.Errorf("layers modified, got %d exp %d", got, exp)
}
// Flatten the diff layer into the bottom accumulator
if err := snaps.Cap(common.HexToHash("0x04"), 2); err != nil {
if err := snaps.Cap(common.HexToHash("0x04"), 1); err != nil {
t.Fatalf("failed to flatten diff layer into accumulator: %v", err)
}
// Since the accumulator diff layer was modified, ensure that data retrievald on the external reference fail
Expand Down Expand Up @@ -389,25 +342,24 @@ func TestSnaphots(t *testing.T) {
// Create a starting base layer and a snapshot tree out of it
base := &diskLayer{
diskdb: rawdb.NewMemoryDatabase(),
root: common.HexToHash("0x01"),
root: makeRoot(1),
cache: fastcache.New(1024 * 500),
}
snaps := &Tree{
layers: map[common.Hash]snapshot{
base.root: base,
},
}
// Construct the snapshots with 128 layers
// Construct the snapshots with 129 layers, flattening whatever's above that
var (
last = common.HexToHash("0x01")
head common.Hash
)
// Flush another 128 layers, one diff will be flatten into the parent.
for i := 0; i < 128; i++ {
for i := 0; i < 129; i++ {
head = makeRoot(uint64(i + 2))
snaps.Update(head, last, nil, setAccount(fmt.Sprintf("%d", i+2)), nil)
last = head
snaps.Cap(head, 128) // 129 layers(128 diffs + 1 disk) are allowed, 129th is the persistent layer
snaps.Cap(head, 128) // 130 layers (128 diffs + 1 accumulator + 1 disk)
}
var cases = []struct {
headRoot common.Hash
Expand All @@ -417,22 +369,57 @@ func TestSnaphots(t *testing.T) {
expectBottom common.Hash
}{
{head, 0, false, 0, common.Hash{}},
{head, 64, false, 64, makeRoot(127 + 2 - 63)},
{head, 128, false, 128, makeRoot(2)}, // All diff layers
{head, 129, true, 128, makeRoot(2)}, // All diff layers
{head, 129, false, 129, common.HexToHash("0x01")}, // All diff layers + disk layer
{head, 64, false, 64, makeRoot(129 + 2 - 64)},
{head, 128, false, 128, makeRoot(3)}, // Normal diff layers, no accumulator
{head, 129, true, 129, makeRoot(2)}, // All diff layers, including accumulator
{head, 130, false, 130, makeRoot(1)}, // All diff layers + disk layer
}
for i, c := range cases {
layers := snaps.Snapshots(c.headRoot, c.limit, c.nodisk)
if len(layers) != c.expected {
t.Errorf("non-overflow test %d: returned snapshot layers are mismatched, want %v, got %v", i, c.expected, len(layers))
}
if len(layers) == 0 {
continue
}
bottommost := layers[len(layers)-1]
if bottommost.Root() != c.expectBottom {
t.Errorf("non-overflow test %d: snapshot mismatch, want %v, get %v", i, c.expectBottom, bottommost.Root())
}
}
// Above we've tested the normal capping, which leaves the accumulator live.
// Test that if the bottommost accumulator diff layer overflows the allowed
// memory limit, the snapshot tree gets capped to one less layer.
// Commit the diff layer onto the disk and ensure it's persisted
defer func(memcap uint64) { aggregatorMemoryLimit = memcap }(aggregatorMemoryLimit)
aggregatorMemoryLimit = 0

snaps.Cap(head, 128) // 129 (128 diffs + 1 overflown accumulator + 1 disk)

cases = []struct {
headRoot common.Hash
limit int
nodisk bool
expected int
expectBottom common.Hash
}{
{head, 0, false, 0, common.Hash{}},
{head, 64, false, 64, makeRoot(129 + 2 - 64)},
{head, 128, false, 128, makeRoot(3)}, // All diff layers, accumulator was flattened
{head, 129, true, 128, makeRoot(3)}, // All diff layers, accumulator was flattened
{head, 130, false, 129, makeRoot(2)}, // All diff layers + disk layer
}
for _, c := range cases {
for i, c := range cases {
layers := snaps.Snapshots(c.headRoot, c.limit, c.nodisk)
if len(layers) != c.expected {
t.Fatalf("Returned snapshot layers are mismatched, want %v, got %v", c.expected, len(layers))
t.Errorf("overflow test %d: returned snapshot layers are mismatched, want %v, got %v", i, c.expected, len(layers))
}
if len(layers) == 0 {
continue
}
bottommost := layers[len(layers)-1]
if bottommost.Root() != c.expectBottom {
t.Fatalf("Snapshot mismatch, want %v, get %v", c.expectBottom, bottommost.Root())
t.Errorf("overflow test %d: snapshot mismatch, want %v, get %v", i, c.expectBottom, bottommost.Root())
}
}
}

0 comments on commit 9ec3329

Please sign in to comment.