Skip to content

Commit

Permalink
Fix API asymmetries with Watch variants
Browse files Browse the repository at this point in the history
Add *Watch variants to LowerBound, Prefix and All to
make them work the same way as List and Get.

Signed-off-by: Jussi Maki <[email protected]>
  • Loading branch information
joamaki committed Jul 4, 2024
1 parent 0618f42 commit 53bb7bc
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 63 deletions.
39 changes: 22 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,30 @@ func example() {
// Query the objects with a snapshot of the database.
txn := db.ReadTxn()

if obj, _, found := myObjects.Get(wtxn, IDIndex.Query(1)); found {
if obj, _, found := myObjects.Get(txn, IDIndex.Query(1)); found {
...
}

iter, watch := myObjects.All()
// Iterate all objects

iter, watch = myObjects.LowerBound(IDIndex.Query(2))
// Iterate objects with ID >= 2

iter, watch = myObjects.Prefix(IDIndex.Query(0x1000_0000))
// Iterate objects where ID is between 0x1000_0000 and 0x1fff_ffff

iter := myObjects.All()
for obj, revision, ok := iter.Next(); ok; obj, revision, ok = iter.Next() {
...
}

// Wait until the query results change.

// Iterate all objects and then wait until something changes.
iter, watch := myObjects.AllWatch(txn)
for ... {}
<-watch
// Grab a new snapshot to read the new changes.
txn = db.ReadTxn()

// Iterate objects with ID >= 2
iter, watch = myObjects.LowerBoundWatch(txn, IDIndex.Query(2))
for ... {}

// Iterate objects where ID is between 0x1000_0000 and 0x1fff_ffff
iter, watch = myObjects.PrefixWatch(txn, IDIndex.Query(0x1000_0000))
for ... {}
}
```

Expand Down Expand Up @@ -349,7 +354,7 @@ for obj, revision, ok := iter.Next(); ok; obj, revision, ok = iter.Next() { ...
<-watch
```

`Prefix` can be used to iterate over objects that match a given prefix.
`Prefix` or `PrefixWatch` can be used to iterate over objects that match a given prefix.

```go
// Prefix does a prefix search on an index. Here it returns an iterator
Expand All @@ -362,16 +367,16 @@ for obj, revision, ok := iter.Next(); ok; obj, revision, ok = iter.Next() {
<-watch
```

`LowerBound` can be used to iterate over objects that have a key equal
to or higher than given key.
`LowerBound` or `LowerBoundWatch` can be used to iterate over objects that
have a key equal to or higher than given key.

```go
// LowerBound can be used to find all objects with a key equal to or higher
// LowerBoundWatch can be used to find all objects with a key equal to or higher
// than specified key. The semantics of it depends on how the indexer works.
// For example index.Uint32 returns the big-endian or most significant byte
// first form of the integer, in other words the number 3 is the key
// []byte{0, 0, 0, 3}, which allows doing a meaningful LowerBound search on it.
iter, watch = myObjects.LowerBound(txn, IDIndex.Query(3))
iter, watch = myObjects.LowerBoundWatch(txn, IDIndex.Query(3))
for obj, revision, ok := iter.Next(); ok; obj, revision, ok = iter.Next() {
// obj.ID >= 3
}
Expand All @@ -397,7 +402,7 @@ with `ByRevision`.
// we can use this to wait for new changes!
lastRevision := statedb.Revision(0)
for {
iter, watch = myObjects.LowerBound(txn, statedb.ByRevision(lastRevision+1))
iter, watch = myObjects.LowerBoundWatch(txn, statedb.ByRevision(lastRevision+1))
for obj, revision, ok := iter.Next(); ok; obj, revision, ok = iter.Next() {
lastRevision = revision
}
Expand Down
6 changes: 3 additions & 3 deletions benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ func BenchmarkDB_FullIteration_All(b *testing.B) {

for j := 0; j < b.N; j++ {
txn := db.ReadTxn()
iter, _ := table.All(txn)
iter := table.All(txn)
i := uint64(0)
for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() {
if obj.ID != i {
Expand Down Expand Up @@ -449,11 +449,11 @@ func BenchmarkDB_PropagationDelay(b *testing.B) {

// Grab a watch channel on the second table
txn := db.ReadTxn()
_, watch2 := table2.All(txn)
_, watch2 := table2.AllWatch(txn)

// Propagate the batch from first table to the second table
var iter Iterator[testObject]
iter, watch1 = table1.LowerBound(txn, ByRevision[testObject](revision))
iter, watch1 = table1.LowerBoundWatch(txn, ByRevision[testObject](revision))
wtxn = db.WriteTxn(table2)
for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() {
table2.Insert(wtxn, testObject2(obj))
Expand Down
35 changes: 15 additions & 20 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func TestDB_LowerBound_ByRevision(t *testing.T) {

txn := db.ReadTxn()

iter, watch := table.LowerBound(txn, ByRevision[testObject](0))
iter, watch := table.LowerBoundWatch(txn, ByRevision[testObject](0))
obj, rev, ok := iter.Next()
require.True(t, ok, "expected ByRevision(rev1) to return results")
require.EqualValues(t, 42, obj.ID)
Expand All @@ -181,7 +181,7 @@ func TestDB_LowerBound_ByRevision(t *testing.T) {
_, _, ok = iter.Next()
require.False(t, ok)

iter, _ = table.LowerBound(txn, ByRevision[testObject](prevRev+1))
iter, _ = table.LowerBoundWatch(txn, ByRevision[testObject](prevRev+1))
obj, _, ok = iter.Next()
require.True(t, ok, "expected ByRevision(rev2) to return results")
require.EqualValues(t, 71, obj.ID)
Expand All @@ -208,7 +208,7 @@ func TestDB_LowerBound_ByRevision(t *testing.T) {
}

txn = db.ReadTxn()
iter, _ = table.LowerBound(txn, ByRevision[testObject](rev+1))
iter, _ = table.LowerBoundWatch(txn, ByRevision[testObject](rev+1))
obj, _, ok = iter.Next()
require.True(t, ok, "expected ByRevision(rev2+1) to return results")
require.EqualValues(t, 71, obj.ID)
Expand All @@ -234,7 +234,7 @@ func TestDB_Prefix(t *testing.T) {

txn := db.ReadTxn()

iter, watch := table.Prefix(txn, tagsIndex.Query("ab"))
iter, watch := table.PrefixWatch(txn, tagsIndex.Query("ab"))
require.Equal(t, Collect(Map(iter, testObject.getID)), []uint64{71, 82})

select {
Expand Down Expand Up @@ -270,7 +270,7 @@ func TestDB_Prefix(t *testing.T) {
}

txn = db.ReadTxn()
iter, _ = table.Prefix(txn, tagsIndex.Query("ab"))
iter = table.Prefix(txn, tagsIndex.Query("ab"))
require.Equal(t, Collect(Map(iter, testObject.getID)), []uint64{71, 82, 99})
}

Expand Down Expand Up @@ -332,7 +332,7 @@ func TestDB_EventIterator(t *testing.T) {

// 1 object should exist.
txn := db.ReadTxn()
iterAll, _ := table.All(txn)
iterAll := table.All(txn)
objs := Collect(iterAll)
require.Len(t, objs, 1)

Expand Down Expand Up @@ -539,7 +539,7 @@ func TestDB_All(t *testing.T) {
require.NoError(t, err, "Insert failed")
_, _, err = table.Insert(txn, testObject{ID: uint64(3)})
require.NoError(t, err, "Insert failed")
iter, _ := table.All(txn)
iter := table.All(txn)
objs := Collect(iter)
require.Len(t, objs, 3)
require.EqualValues(t, 1, objs[0].ID)
Expand All @@ -549,7 +549,7 @@ func TestDB_All(t *testing.T) {
}

txn := db.ReadTxn()
iter, watch := table.All(txn)
iter, watch := table.AllWatch(txn)
objs := Collect(iter)
require.Len(t, objs, 3)
require.EqualValues(t, 1, objs[0].ID)
Expand All @@ -571,7 +571,7 @@ func TestDB_All(t *testing.T) {
}

// Prior read transaction not affected by delete.
iter, _ = table.All(txn)
iter = table.All(txn)
objs = Collect(iter)
require.Len(t, objs, 3)

Expand Down Expand Up @@ -764,7 +764,7 @@ func TestDB_CompareAndSwap_CompareAndDelete(t *testing.T) {
require.ErrorIs(t, ErrObjectNotFound, err)
require.False(t, hadOld)

objs, _ := table.All(wtxn)
objs := table.All(wtxn)
require.Len(t, Collect(objs), 0)

wtxn.Abort()
Expand Down Expand Up @@ -852,29 +852,24 @@ func TestDB_ReadAfterWrite(t *testing.T) {

txn := db.WriteTxn(table)

iter, _ := table.All(txn)
require.Len(t, Collect(iter), 0)
require.Len(t, Collect(table.All(txn)), 0)

_, _, err := table.Insert(txn, testObject{ID: 1})
require.NoError(t, err, "Insert failed")

iter, _ = table.All(txn)
require.Len(t, Collect(iter), 1)
require.Len(t, Collect(table.All(txn)), 1)

_, hadOld, _ := table.Delete(txn, testObject{ID: 1})
require.True(t, hadOld)
iter, _ = table.All(txn)
require.Len(t, Collect(iter), 0)
require.Len(t, Collect(table.All(txn)), 0)

_, _, err = table.Insert(txn, testObject{ID: 2})
require.NoError(t, err, "Insert failed")
iter, _ = table.All(txn)
require.Len(t, Collect(iter), 1)
require.Len(t, Collect(table.All(txn)), 1)

txn.Commit()

iter, _ = table.All(db.ReadTxn())
require.Len(t, Collect(iter), 1)
require.Len(t, Collect(table.All(db.ReadTxn())), 1)
}

func TestDB_Initialization(t *testing.T) {
Expand Down
3 changes: 1 addition & 2 deletions derive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ func TestDerive(t *testing.T) {

getDerived := func() []derived {
txn := db.ReadTxn()
iter, _ := outTable.All(txn)
objs := Collect(iter)
objs := Collect(outTable.All(txn))
// Log so we can trace the failed eventually calls
t.Logf("derived: %+v", objs)
return objs
Expand Down
12 changes: 6 additions & 6 deletions fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (a *realActionLog) validateTable(txn statedb.ReadTxn, table statedb.Table[f
// Since everything was deleted we can clear the log entries for this table now
a.log[table.Name()] = nil

iter, _ := table.All(txn)
iter := table.All(txn)
actual := map[string]struct{}{}
for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() {
actual[obj.id] = struct{}{}
Expand Down Expand Up @@ -247,7 +247,7 @@ func deleteManyAction(ctx actionContext) {
// nothing bad happens when the iterator is used while deleting.
toDelete := ctx.table.NumObjects(ctx.txn) / 3

iter, _ := ctx.table.All(ctx.txn)
iter := ctx.table.All(ctx.txn)
n := 0
for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() {
ctx.log.log("%s: DeleteMany %s (%d/%d)", ctx.table.Name(), obj.id, n+1, toDelete)
Expand All @@ -267,7 +267,7 @@ func deleteManyAction(ctx actionContext) {
}

func allAction(ctx actionContext) {
iter, _ := ctx.table.All(ctx.txn)
iter := ctx.table.All(ctx.txn)
ctx.log.log("%s: All => %d found", ctx.table.Name(), len(statedb.Collect(iter)))
}

Expand Down Expand Up @@ -319,13 +319,13 @@ func getAction(ctx actionContext) {

func lowerboundAction(ctx actionContext) {
id := mkID()
iter, _ := ctx.table.LowerBound(ctx.txn, idIndex.Query(id))
iter, _ := ctx.table.LowerBoundWatch(ctx.txn, idIndex.Query(id))
ctx.log.log("%s: LowerBound(%s) => %d found", ctx.table.Name(), id, len(statedb.Collect(iter)))
}

func prefixAction(ctx actionContext) {
id := mkID()
iter, _ := ctx.table.Prefix(ctx.txn, idIndex.Query(id))
iter := ctx.table.Prefix(ctx.txn, idIndex.Query(id))
ctx.log.log("%s: Prefix(%s) => %d found", ctx.table.Name(), id, len(statedb.Collect(iter)))
}

Expand Down Expand Up @@ -417,7 +417,7 @@ func trackerWorker(i int, stop <-chan struct{}) {
// Validate that the observed changes match with the database state at this
// snapshot.
state2 := maps.Clone(state)
iterAll, _ := tableFuzz1.LowerBound(txn, statedb.ByRevision[fuzzObj](0))
iterAll := tableFuzz1.LowerBound(txn, statedb.ByRevision[fuzzObj](0))
for obj, rev, ok := iterAll.Next(); ok; obj, rev, ok = iterAll.Next() {
change, found := state[obj.id]
if !found {
Expand Down
2 changes: 1 addition & 1 deletion iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func (it *changeIterator[Obj]) Watch(txn ReadTxn) <-chan struct{} {
return it.watch
}

updateIter, watch := it.table.LowerBound(txn, ByRevision[Obj](it.revision+1))
updateIter, watch := it.table.LowerBoundWatch(txn, ByRevision[Obj](it.revision+1))
deleteIter := it.dt.deleted(txn, it.revision+1)
it.iter = NewDualIterator(deleteIter, updateIter)

Expand Down
2 changes: 1 addition & 1 deletion iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestFilter(t *testing.T) {
table.Insert(txn, &testObject{ID: 5})
txn.Commit()

iter, _ := table.All(db.ReadTxn())
iter := table.All(db.ReadTxn())
filtered := Collect(
Map(
Filter(
Expand Down
2 changes: 1 addition & 1 deletion reconciler/benchmark/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func main() {
}

// Check that all statuses are correctly set.
iter, _ := testObjects.All(db.ReadTxn())
iter := testObjects.All(db.ReadTxn())
for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() {
if obj.status.Kind != reconciler.StatusKindDone {
log.Fatalf("Object with unexpected status: %#v", obj)
Expand Down
4 changes: 2 additions & 2 deletions reconciler/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (r *reconciler[Obj]) reconcileLoop(ctx context.Context, health cell.Health)

// prune performs the Prune operation to delete unexpected objects in the target system.
func (r *reconciler[Obj]) prune(ctx context.Context, txn statedb.ReadTxn) error {
iter, _ := r.config.Table.All(txn)
iter := r.config.Table.All(txn)
start := time.Now()
err := r.config.Operations.Prune(ctx, txn, iter)
if err != nil {
Expand Down Expand Up @@ -145,7 +145,7 @@ outer:
// Iterate over the objects in revision order, e.g. oldest modification first.
// We look for objects that are older than [RefreshInterval] and mark them for
// pending in order for them to be reconciled again.
iter, _ := r.config.Table.LowerBound(r.DB.ReadTxn(), statedb.ByRevision[Obj](lastRevision+1))
iter := r.config.Table.LowerBound(r.DB.ReadTxn(), statedb.ByRevision[Obj](lastRevision+1))
indexer := r.config.Table.PrimaryIndexer()
for obj, rev, ok := iter.Next(); ok; obj, rev, ok = iter.Next() {
status := r.config.GetObjectStatus(obj)
Expand Down
27 changes: 21 additions & 6 deletions table.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,12 @@ func (t *genTable[Obj]) GetWatch(txn ReadTxn, q Query[Obj]) (obj Obj, revision u
return
}

func (t *genTable[Obj]) LowerBound(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <-chan struct{}) {
func (t *genTable[Obj]) LowerBound(txn ReadTxn, q Query[Obj]) Iterator[Obj] {
iter, _ := t.LowerBoundWatch(txn, q)
return iter
}

func (t *genTable[Obj]) LowerBoundWatch(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <-chan struct{}) {
indexTxn := txn.getTxn().mustIndexReadTxn(t, t.indexPos(q.index))
// Since LowerBound query may be invalidated by changes in another branch
// of the tree, we cannot just simply watch the node we seeked to. Instead
Expand All @@ -260,13 +265,23 @@ func (t *genTable[Obj]) LowerBound(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <-
return &iterator[Obj]{iter}, watch
}

func (t *genTable[Obj]) Prefix(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <-chan struct{}) {
func (t *genTable[Obj]) Prefix(txn ReadTxn, q Query[Obj]) Iterator[Obj] {
iter, _ := t.PrefixWatch(txn, q)
return iter
}

func (t *genTable[Obj]) PrefixWatch(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <-chan struct{}) {
indexTxn := txn.getTxn().mustIndexReadTxn(t, t.indexPos(q.index))
iter, watch := indexTxn.Prefix(q.key)
return &iterator[Obj]{iter}, watch
}

func (t *genTable[Obj]) All(txn ReadTxn) (Iterator[Obj], <-chan struct{}) {
func (t *genTable[Obj]) All(txn ReadTxn) Iterator[Obj] {
iter, _ := t.AllWatch(txn)
return iter
}

func (t *genTable[Obj]) AllWatch(txn ReadTxn) (Iterator[Obj], <-chan struct{}) {
indexTxn := txn.getTxn().mustIndexReadTxn(t, PrimaryIndexPos)
watch := indexTxn.RootWatch()
return &iterator[Obj]{indexTxn.Iterator()}, watch
Expand Down Expand Up @@ -323,7 +338,7 @@ func (t *genTable[Obj]) CompareAndDelete(txn WriteTxn, rev Revision, obj Obj) (o
}

func (t *genTable[Obj]) DeleteAll(txn WriteTxn) error {
iter, _ := t.All(txn)
iter := t.All(txn)
itxn := txn.getTxn()
for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() {
_, _, err := itxn.delete(t, Revision(0), obj)
Expand Down Expand Up @@ -354,8 +369,8 @@ func (t *genTable[Obj]) Changes(txn WriteTxn) (ChangeIterator[Obj], error) {
}

// Prepare the iterator
updateIter, watch := t.LowerBound(txn, ByRevision[Obj](0)) // observe all current objects
deleteIter := iter.dt.deleted(txn, iter.dt.getRevision()) // only observe new deletions
updateIter, watch := t.LowerBoundWatch(txn, ByRevision[Obj](0)) // observe all current objects
deleteIter := iter.dt.deleted(txn, iter.dt.getRevision()) // only observe new deletions
iter.iter = NewDualIterator(deleteIter, updateIter)
iter.watch = watch

Expand Down
Loading

0 comments on commit 53bb7bc

Please sign in to comment.