part: Fix bug with short and long key with shared prefix #22

Merged 6 commits on Apr 29, 2024
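Before this change, inserting a long key and a short key that share a prefix (for example "foobar" and "foo") could leave one of the values unretrievable; the new Test_prefix_regression and the promote() change below both target that case. A minimal sketch of the scenario, mirroring the regression test added in this PR; the import path github.com/cilium/statedb/part is an assumption:

```go
package main

import (
	"fmt"

	"github.com/cilium/statedb/part" // assumed module path
)

func main() {
	// Insert a long key first, then a short key that is a strict
	// prefix of it. Both must remain retrievable afterwards.
	tree := part.New[string]()
	_, _, tree = tree.Insert([]byte("foobar"), "foobar")
	_, _, tree = tree.Insert([]byte("foo"), "foo")

	for _, k := range []string{"foo", "foobar"} {
		v, _, found := tree.Get([]byte(k))
		fmt.Printf("Get(%q) => %q, found=%v\n", k, v, found)
	}
}
```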
52 changes: 27 additions & 25 deletions fuzz_test.go
@@ -52,24 +52,26 @@ const (
)

type fuzzObj struct {
id uint64
id string
value uint64
}

func mkID() uint64 {
return 1 + uint64(rand.Int63n(numUniqueIDs))
func mkID() string {
// We use a hex string representation instead of the raw uint64 so we get
// a wide range of different-length keys and different prefixes.
return fmt.Sprintf("%x", 1+uint64(rand.Int63n(numUniqueIDs)))
}

func mkValue() uint64 {
return 1 + uint64(rand.Int63n(numUniqueValues))
}

var idIndex = statedb.Index[fuzzObj, uint64]{
var idIndex = statedb.Index[fuzzObj, string]{
Name: "id",
FromObject: func(obj fuzzObj) index.KeySet {
return index.NewKeySet(index.Uint64(obj.id))
return index.NewKeySet(index.String(obj.id))
},
FromKey: index.Uint64,
FromKey: index.String,
Unique: true,
}

@@ -83,10 +83,10 @@ var valueIndex = statedb.Index[fuzzObj, uint64]{
}

var (
tableFuzz1 = statedb.MustNewTable[fuzzObj]("fuzz1", idIndex, valueIndex)
tableFuzz2 = statedb.MustNewTable[fuzzObj]("fuzz2", idIndex, valueIndex)
tableFuzz3 = statedb.MustNewTable[fuzzObj]("fuzz3", idIndex, valueIndex)
tableFuzz4 = statedb.MustNewTable[fuzzObj]("fuzz4", idIndex, valueIndex)
tableFuzz1 = statedb.MustNewTable("fuzz1", idIndex, valueIndex)
tableFuzz2 = statedb.MustNewTable("fuzz2", idIndex, valueIndex)
tableFuzz3 = statedb.MustNewTable("fuzz3", idIndex, valueIndex)
tableFuzz4 = statedb.MustNewTable("fuzz4", idIndex, valueIndex)
fuzzTables = []statedb.TableMeta{tableFuzz1, tableFuzz2, tableFuzz3, tableFuzz4}
fuzzMetrics = statedb.NewExpVarMetrics(false)
fuzzDB *statedb.DB
@@ -123,7 +125,7 @@ func (a *realActionLog) validateTable(txn statedb.ReadTxn, table statedb.Table[f
defer a.Unlock()

// Collapse the log down to objects that are alive at the end.
alive := map[uint64]struct{}{}
alive := map[string]struct{}{}
for _, e := range a.log[table.Name()] {
switch e.act {
case actInsert:
@@ -139,7 +141,7 @@ func (a *realActionLog) validateTable(txn statedb.ReadTxn, table statedb.Table[f
a.log[table.Name()] = nil

iter, _ := table.All(txn)
actual := map[uint64]struct{}{}
actual := map[string]struct{}{}
for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() {
actual[obj.id] = struct{}{}
}
@@ -185,13 +187,13 @@ const (
type actionLogEntry struct {
table statedb.Table[fuzzObj]
act int
id uint64
id string
value uint64
}

type tableAndID struct {
table string
id uint64
id string
}

type txnActionLog struct {
@@ -212,7 +214,7 @@ type action func(ctx actionContext)
func insertAction(ctx actionContext) {
id := mkID()
value := mkValue()
ctx.log.log("%s: Insert %d", ctx.table.Name(), id)
ctx.log.log("%s: Insert %s", ctx.table.Name(), id)
ctx.table.Insert(ctx.txn, fuzzObj{id, value})
e := actionLogEntry{ctx.table, actInsert, id, value}
ctx.actLog.append(e)
@@ -221,7 +223,7 @@ func insertAction(ctx actionContext) {

func deleteAction(ctx actionContext) {
id := mkID()
ctx.log.log("%s: Delete %d", ctx.table.Name(), id)
ctx.log.log("%s: Delete %s", ctx.table.Name(), id)
ctx.table.Delete(ctx.txn, fuzzObj{id, 0})
e := actionLogEntry{ctx.table, actDelete, id, 0}
ctx.actLog.append(e)
@@ -236,7 +238,7 @@ func deleteAllAction(ctx actionContext) {
panic(err)
}
ctx.table.DeleteAll(ctx.txn)
ctx.actLog.append(actionLogEntry{ctx.table, actDeleteAll, 0, 0})
ctx.actLog.append(actionLogEntry{ctx.table, actDeleteAll, "", 0})
clear(ctx.txnLog.latest)
}

@@ -248,7 +250,7 @@ func deleteManyAction(ctx actionContext) {
iter, _ := ctx.table.All(ctx.txn)
n := 0
for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() {
ctx.log.log("%s: DeleteMany %d (%d/%d)", ctx.table.Name(), obj.id, n+1, toDelete)
ctx.log.log("%s: DeleteMany %s (%d/%d)", ctx.table.Name(), obj.id, n+1, toDelete)
_, hadOld, _ := ctx.table.Delete(ctx.txn, obj)
if !hadOld {
panic("expected Delete of a known object to return the old object")
@@ -312,19 +314,19 @@ func firstAction(ctx actionContext) {
}
}
}
ctx.log.log("%s: First(%d) => rev=%d, ok=%v", ctx.table.Name(), id, rev, ok)
ctx.log.log("%s: First(%s) => rev=%d, ok=%v", ctx.table.Name(), id, rev, ok)
}

func lowerboundAction(ctx actionContext) {
id := mkID()
iter, _ := ctx.table.LowerBound(ctx.txn, idIndex.Query(id))
ctx.log.log("%s: LowerBound(%d) => %d found", ctx.table.Name(), id, len(statedb.Collect(iter)))
ctx.log.log("%s: LowerBound(%s) => %d found", ctx.table.Name(), id, len(statedb.Collect(iter)))
}

func prefixAction(ctx actionContext) {
id := mkID()
iter, _ := ctx.table.Prefix(ctx.txn, idIndex.Query(id))
ctx.log.log("%s: Prefix(%d) => %d found", ctx.table.Name(), id, len(statedb.Collect(iter)))
ctx.log.log("%s: Prefix(%s) => %d found", ctx.table.Name(), id, len(statedb.Collect(iter)))
}

var actions = []action{
@@ -380,7 +382,7 @@ func trackerWorker(i int, stop <-chan struct{}) {
defer iter.Close()

// Keep track of what state the changes lead us to in order to validate it.
state := map[uint64]*statedb.Change[fuzzObj]{}
state := map[string]*statedb.Change[fuzzObj]{}

var txn statedb.ReadTxn
var prevRev statedb.Revision
@@ -399,7 +401,7 @@ }
}
prevRev = rev

if change.Object.id == 0 || change.Object.value == 0 {
if change.Object.id == "" || change.Object.value == 0 {
panic("trackerWorker: object with zero id/value")
}

@@ -419,7 +421,7 @@ func trackerWorker(i int, stop <-chan struct{}) {
for obj, rev, ok := iterAll.Next(); ok; obj, rev, ok = iterAll.Next() {
change, found := state[obj.id]
if !found {
panic(fmt.Sprintf("trackerWorker: object %d not found from state", obj.id))
panic(fmt.Sprintf("trackerWorker: object %s not found from state", obj.id))
}

if change.Revision != rev {
@@ -434,7 +436,7 @@

if len(state2) > 0 {
for id := range state2 {
log.log("%d should not exist\n", id)
log.log("%s should not exist\n", id)
}
panic(fmt.Sprintf("trackerWorker: %d orphan object(s)", len(state2)))
}
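The fuzz_test.go changes above switch the fuzzer's object IDs from fixed-width uint64 keys to their hex string form, so the generated keys vary in length and frequently share prefixes, which is exactly the key shape the fixed bug depends on. A small standalone illustration (not part of the diff) of the kind of IDs mkID now produces:

```go
package main

import "fmt"

func main() {
	// Hex-encoding the numeric ID yields variable-length keys that
	// share prefixes: "1" is a prefix of "1f", which is a prefix of "1ff".
	for _, n := range []uint64{0x1, 0x1f, 0x1ff, 0x2, 0xff} {
		key := fmt.Sprintf("%x", n)
		fmt.Printf("%d -> %q (len %d)\n", n, key, len(key))
	}
	// Output:
	// 1 -> "1" (len 1)
	// 31 -> "1f" (len 2)
	// 511 -> "1ff" (len 3)
	// 2 -> "2" (len 1)
	// 255 -> "ff" (len 2)
}
```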
3 changes: 3 additions & 0 deletions part/node.go
@@ -171,6 +171,7 @@ func (n *header[T]) promote(watch bool) *header[T] {
node4 := n.node4()
node16 := &node16[T]{header: *n}
node16.setKind(nodeKind16)
node16.leaf = n.getLeaf()
size := node4.size()
copy(node16.children[:], node4.children[:size])
copy(node16.keys[:], node4.keys[:size])
@@ -182,6 +183,7 @@ func (n *header[T]) promote(watch bool) *header[T] {
node16 := n.node16()
node48 := &node48[T]{header: *n}
node48.setKind(nodeKind48)
node48.leaf = n.getLeaf()
copy(node48.children[:], node16.children[:node16.size()])
for i, k := range node16.keys[:node16.size()] {
node48.index[k] = int8(i)
@@ -194,6 +196,7 @@ func (n *header[T]) promote(watch bool) *header[T] {
node48 := n.node48()
node256 := &node256[T]{header: *n}
node256.setKind(nodeKind256)
node256.leaf = n.getLeaf()

// Since Node256 has children indexed directly, iterate over the children
// to assign them to the right index.
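The part/node.go hunks above are the core of the fix: promote() copies a full node into the next larger node kind (node4 → node16 → node48 → node256), and it now also carries over the node's leaf, i.e. the value of a key that ends exactly at this node and is therefore a prefix of its children's keys. Previously that leaf was dropped on promotion. A stripped-down, hypothetical sketch of the pattern (the real header[T]/node16[T]/node48[T]/node256[T] carry more state such as watch channels and kind bits):

```go
package main

import "fmt"

// Hypothetical simplified types, for illustration only.
type leaf[T any] struct {
	key   []byte
	value T
}

type node[T any] struct {
	prefix   []byte
	leaf     *leaf[T] // value of a key ending exactly at this node
	keys     []byte
	children []*node[T]
}

// promote grows a full node into one with a larger child capacity.
// The essential point of the fix: the leaf pointer must travel with the
// node. Without it, a short key such as "foo" stored on this node is
// silently lost once enough "foo..." keys force the promotion.
func promote[T any](n *node[T], newCap int) *node[T] {
	out := &node[T]{
		prefix: n.prefix,
		leaf:   n.leaf, // the analogue of the added node16.leaf = n.getLeaf()
	}
	out.keys = append(make([]byte, 0, newCap), n.keys...)
	out.children = append(make([]*node[T], 0, newCap), n.children...)
	return out
}

func main() {
	n := &node[string]{
		prefix: []byte("foo"),
		leaf:   &leaf[string]{key: []byte("foo"), value: "foo"},
	}
	bigger := promote(n, 16)
	fmt.Println(bigger.leaf != nil) // true: the short key "foo" survives
}
```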
112 changes: 103 additions & 9 deletions part/part_test.go
@@ -4,6 +4,7 @@
package part

import (
"bytes"
"encoding/binary"
"fmt"
"math/rand"
@@ -61,6 +62,10 @@ func uint64Key(n uint64) []byte {
return binary.BigEndian.AppendUint64(nil, n)
}

func hexKey(n uint64) []byte {
return []byte(fmt.Sprintf("%x", n))
}

func uint32Key(n uint32) []byte {
return binary.BigEndian.AppendUint32(nil, n)
}
@@ -592,7 +597,6 @@ func Test_lowerbound_edge_cases(t *testing.T) {
next(false, 0)

// Node256
fmt.Println("node256:")
for i := 1; i < 50; i += 2 { // add less than 256 for some holes in node256.children
n := uint32(0x20000 + i)
_, _, tree = tree.Insert(uint32Key(n), n)
@@ -641,29 +645,119 @@ func Test_lowerbound_regression(t *testing.T) {
require.Equal(t, len(values), i)
}

func Test_prefix_regression(t *testing.T) {
// Regression test for a bug where a long key and a short key sharing
// a prefix were inserted.

tree := New[string]()
_, _, tree = tree.Insert([]byte("foobar"), "foobar")
_, _, tree = tree.Insert([]byte("foo"), "foo")

s, _, found := tree.Get([]byte("foobar"))
require.True(t, found)
require.Equal(t, s, "foobar")

s, _, found = tree.Get([]byte("foo"))
require.True(t, found)
require.Equal(t, s, "foo")
}

func Test_iterate(t *testing.T) {
sizes := []int{0, 1, 10, 100, 1000, rand.Intn(1000)}
sizes := []int{0, 1, 10, 100, 1000}
for _, size := range sizes {
t.Logf("size=%d", size)
tree := New[uint64]()
keys := []uint64{}
for i := 0; i < size; i++ {
_, _, tree = tree.Insert(uint64Key(uint64(i)), uint64(i))
keys = append(keys, uint64(i))
}

rand.Shuffle(len(keys), func(i, j int) {
keys[i], keys[j] = keys[j], keys[i]
})
for _, i := range keys {
_, _, tree = tree.Insert(hexKey(uint64(i)), uint64(i))
}

iter := tree.LowerBound([]byte{})
i := uint64(0)
for _, obj, ok := iter.Next(); ok; _, obj, ok = iter.Next() {
if obj != uint64(i) {
t.Fatalf("expected %d, got %d", i, obj)
// Insert again and validate that the old value is returned
for _, i := range keys {
var old uint64
var hadOld bool
old, hadOld, tree = tree.Insert(hexKey(uint64(i)), uint64(i))
assert.True(t, hadOld, "hadOld")
assert.Equal(t, old, uint64(i))
}

// The order for the variable-length keys is based on prefix,
// so we would get 0x0105 before 0x02, since it has a "smaller"
// prefix. Hence we just check that we see all values.
iter := tree.Iterator()
i := int(0)
for key, obj, ok := iter.Next(); ok; key, obj, ok = iter.Next() {
if !bytes.Equal(hexKey(obj), key) {
t.Fatalf("expected %x, got %x", key, hexKey(obj))
}
i++
}
require.EqualValues(t, i, size)
if !assert.Equal(t, size, i) {
tree.PrintTree()
t.FailNow()
}

_, _, ok := iter.Next()
require.False(t, ok, "expected exhausted iterator to keep returning false")

// Delete keys one at a time, in random order.
rand.Shuffle(len(keys), func(i, j int) {
keys[i], keys[j] = keys[j], keys[i]
})
txn := tree.Txn()
n := rand.Intn(20)
for i, k := range keys {
txn.Delete(hexKey(uint64(k)))

n--
if n <= 0 {
tree = txn.Commit()
txn = tree.Txn()
n = rand.Intn(20)
}

// All the rest of the keys can still be found
for _, j := range keys[i+1:] {
n, _, found := txn.Get(hexKey(j))
if !assert.True(t, found) || !assert.Equal(t, n, j) {
fmt.Println("--- new tree")
txn.PrintTree()
t.FailNow()
}
}
}

}
}

func Test_closed_chan_regression(t *testing.T) {
tree := New[uint64]()
_, _, tree = tree.Insert(hexKey(uint64(0)), uint64(0))
_, _, tree = tree.Insert(hexKey(uint64(1)), uint64(1))
_, _, tree = tree.Insert(hexKey(uint64(2)), uint64(2))
_, _, tree = tree.Insert(hexKey(uint64(3)), uint64(3))

txn := tree.Txn()
txn.Delete(hexKey(uint64(3)))
txn.Delete(hexKey(uint64(1)))
tree = txn.Commit()

// No reachable channel should be closed
for _, c := range tree.root.children() {
select {
case <-c.watch:
t.Logf("%x %p closed already", c.prefix, &c.watch)
t.FailNow()
default:
}
}
}

func Test_lowerbound_bigger(t *testing.T) {