From 53bb7bc0dd7d810de6dcb84d079fe7fb4d600cae Mon Sep 17 00:00:00 2001 From: Jussi Maki Date: Wed, 3 Jul 2024 16:27:18 +0200 Subject: [PATCH] Fix API asymmetries with Watch variants Add *Watch variants to LowerBound, Prefix and All to make them work the same way as List and Get. Signed-off-by: Jussi Maki --- README.md | 39 ++++++++++++++++++++---------------- benchmarks_test.go | 6 +++--- db_test.go | 35 ++++++++++++++------------------ derive_test.go | 3 +-- fuzz_test.go | 12 +++++------ iterator.go | 2 +- iterator_test.go | 2 +- reconciler/benchmark/main.go | 2 +- reconciler/reconciler.go | 4 ++-- table.go | 27 +++++++++++++++++++------ types.go | 19 ++++++++++++++---- 11 files changed, 88 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index 38dc21e..e622511 100644 --- a/README.md +++ b/README.md @@ -111,25 +111,30 @@ func example() { // Query the objects with a snapshot of the database. txn := db.ReadTxn() - if obj, _, found := myObjects.Get(wtxn, IDIndex.Query(1)); found { + if obj, _, found := myObjects.Get(txn, IDIndex.Query(1)); found { ... } - iter, watch := myObjects.All() // Iterate all objects - - iter, watch = myObjects.LowerBound(IDIndex.Query(2)) - // Iterate objects with ID >= 2 - - iter, watch = myObjects.Prefix(IDIndex.Query(0x1000_0000)) - // Iterate objects where ID is between 0x1000_0000 and 0x1fff_ffff - + iter := myObjects.All() for obj, revision, ok := iter.Next(); ok; obj, revision, ok = iter.Next() { ... } - - // Wait until the query results change. + + // Iterate all objects and then wait until something changes. + iter, watch := myObjects.AllWatch(txn) + for ... {} <-watch + // Grab a new snapshot to read the new changes. + txn = db.ReadTxn() + + // Iterate objects with ID >= 2 + iter, watch = myObjects.LowerBoundWatch(txn, IDIndex.Query(2)) + for ... {} + + // Iterate objects where ID is between 0x1000_0000 and 0x1fff_ffff + iter, watch = myObjects.PrefixWatch(txn, IDIndex.Query(0x1000_0000)) + for ... {} } ``` @@ -349,7 +354,7 @@ for obj, revision, ok := iter.Next(); ok; obj, revision, ok = iter.Next() { ... <-watch ``` -`Prefix` can be used to iterate over objects that match a given prefix. +`Prefix` or `PrefixWatch` can be used to iterate over objects that match a given prefix. ```go // Prefix does a prefix search on an index. Here it returns an iterator @@ -362,16 +367,16 @@ for obj, revision, ok := iter.Next(); ok; obj, revision, ok = iter.Next() { <-watch ``` -`LowerBound` can be used to iterate over objects that have a key equal -to or higher than given key. +`LowerBound` or `LowerBoundWatch` can be used to iterate over objects that +have a key equal to or higher than given key. ```go -// LowerBound can be used to find all objects with a key equal to or higher +// LowerBoundWatch can be used to find all objects with a key equal to or higher // than specified key. The semantics of it depends on how the indexer works. // For example index.Uint32 returns the big-endian or most significant byte // first form of the integer, in other words the number 3 is the key // []byte{0, 0, 0, 3}, which allows doing a meaningful LowerBound search on it. -iter, watch = myObjects.LowerBound(txn, IDIndex.Query(3)) +iter, watch = myObjects.LowerBoundWatch(txn, IDIndex.Query(3)) for obj, revision, ok := iter.Next(); ok; obj, revision, ok = iter.Next() { // obj.ID >= 3 } @@ -397,7 +402,7 @@ with `ByRevision`. // we can use this to wait for new changes! lastRevision := statedb.Revision(0) for { - iter, watch = myObjects.LowerBound(txn, statedb.ByRevision(lastRevision+1)) + iter, watch = myObjects.LowerBoundWatch(txn, statedb.ByRevision(lastRevision+1)) for obj, revision, ok := iter.Next(); ok; obj, revision, ok = iter.Next() { lastRevision = revision } diff --git a/benchmarks_test.go b/benchmarks_test.go index ab1f974..9c87250 100644 --- a/benchmarks_test.go +++ b/benchmarks_test.go @@ -342,7 +342,7 @@ func BenchmarkDB_FullIteration_All(b *testing.B) { for j := 0; j < b.N; j++ { txn := db.ReadTxn() - iter, _ := table.All(txn) + iter := table.All(txn) i := uint64(0) for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() { if obj.ID != i { @@ -449,11 +449,11 @@ func BenchmarkDB_PropagationDelay(b *testing.B) { // Grab a watch channel on the second table txn := db.ReadTxn() - _, watch2 := table2.All(txn) + _, watch2 := table2.AllWatch(txn) // Propagate the batch from first table to the second table var iter Iterator[testObject] - iter, watch1 = table1.LowerBound(txn, ByRevision[testObject](revision)) + iter, watch1 = table1.LowerBoundWatch(txn, ByRevision[testObject](revision)) wtxn = db.WriteTxn(table2) for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() { table2.Insert(wtxn, testObject2(obj)) diff --git a/db_test.go b/db_test.go index a271c35..33e08eb 100644 --- a/db_test.go +++ b/db_test.go @@ -169,7 +169,7 @@ func TestDB_LowerBound_ByRevision(t *testing.T) { txn := db.ReadTxn() - iter, watch := table.LowerBound(txn, ByRevision[testObject](0)) + iter, watch := table.LowerBoundWatch(txn, ByRevision[testObject](0)) obj, rev, ok := iter.Next() require.True(t, ok, "expected ByRevision(rev1) to return results") require.EqualValues(t, 42, obj.ID) @@ -181,7 +181,7 @@ func TestDB_LowerBound_ByRevision(t *testing.T) { _, _, ok = iter.Next() require.False(t, ok) - iter, _ = table.LowerBound(txn, ByRevision[testObject](prevRev+1)) + iter, _ = table.LowerBoundWatch(txn, ByRevision[testObject](prevRev+1)) obj, _, ok = iter.Next() require.True(t, ok, "expected ByRevision(rev2) to return results") require.EqualValues(t, 71, obj.ID) @@ -208,7 +208,7 @@ func TestDB_LowerBound_ByRevision(t *testing.T) { } txn = db.ReadTxn() - iter, _ = table.LowerBound(txn, ByRevision[testObject](rev+1)) + iter, _ = table.LowerBoundWatch(txn, ByRevision[testObject](rev+1)) obj, _, ok = iter.Next() require.True(t, ok, "expected ByRevision(rev2+1) to return results") require.EqualValues(t, 71, obj.ID) @@ -234,7 +234,7 @@ func TestDB_Prefix(t *testing.T) { txn := db.ReadTxn() - iter, watch := table.Prefix(txn, tagsIndex.Query("ab")) + iter, watch := table.PrefixWatch(txn, tagsIndex.Query("ab")) require.Equal(t, Collect(Map(iter, testObject.getID)), []uint64{71, 82}) select { @@ -270,7 +270,7 @@ func TestDB_Prefix(t *testing.T) { } txn = db.ReadTxn() - iter, _ = table.Prefix(txn, tagsIndex.Query("ab")) + iter = table.Prefix(txn, tagsIndex.Query("ab")) require.Equal(t, Collect(Map(iter, testObject.getID)), []uint64{71, 82, 99}) } @@ -332,7 +332,7 @@ func TestDB_EventIterator(t *testing.T) { // 1 object should exist. txn := db.ReadTxn() - iterAll, _ := table.All(txn) + iterAll := table.All(txn) objs := Collect(iterAll) require.Len(t, objs, 1) @@ -539,7 +539,7 @@ func TestDB_All(t *testing.T) { require.NoError(t, err, "Insert failed") _, _, err = table.Insert(txn, testObject{ID: uint64(3)}) require.NoError(t, err, "Insert failed") - iter, _ := table.All(txn) + iter := table.All(txn) objs := Collect(iter) require.Len(t, objs, 3) require.EqualValues(t, 1, objs[0].ID) @@ -549,7 +549,7 @@ func TestDB_All(t *testing.T) { } txn := db.ReadTxn() - iter, watch := table.All(txn) + iter, watch := table.AllWatch(txn) objs := Collect(iter) require.Len(t, objs, 3) require.EqualValues(t, 1, objs[0].ID) @@ -571,7 +571,7 @@ func TestDB_All(t *testing.T) { } // Prior read transaction not affected by delete. - iter, _ = table.All(txn) + iter = table.All(txn) objs = Collect(iter) require.Len(t, objs, 3) @@ -764,7 +764,7 @@ func TestDB_CompareAndSwap_CompareAndDelete(t *testing.T) { require.ErrorIs(t, ErrObjectNotFound, err) require.False(t, hadOld) - objs, _ := table.All(wtxn) + objs := table.All(wtxn) require.Len(t, Collect(objs), 0) wtxn.Abort() @@ -852,29 +852,24 @@ func TestDB_ReadAfterWrite(t *testing.T) { txn := db.WriteTxn(table) - iter, _ := table.All(txn) - require.Len(t, Collect(iter), 0) + require.Len(t, Collect(table.All(txn)), 0) _, _, err := table.Insert(txn, testObject{ID: 1}) require.NoError(t, err, "Insert failed") - iter, _ = table.All(txn) - require.Len(t, Collect(iter), 1) + require.Len(t, Collect(table.All(txn)), 1) _, hadOld, _ := table.Delete(txn, testObject{ID: 1}) require.True(t, hadOld) - iter, _ = table.All(txn) - require.Len(t, Collect(iter), 0) + require.Len(t, Collect(table.All(txn)), 0) _, _, err = table.Insert(txn, testObject{ID: 2}) require.NoError(t, err, "Insert failed") - iter, _ = table.All(txn) - require.Len(t, Collect(iter), 1) + require.Len(t, Collect(table.All(txn)), 1) txn.Commit() - iter, _ = table.All(db.ReadTxn()) - require.Len(t, Collect(iter), 1) + require.Len(t, Collect(table.All(db.ReadTxn())), 1) } func TestDB_Initialization(t *testing.T) { diff --git a/derive_test.go b/derive_test.go index b1d4b01..1adb16d 100644 --- a/derive_test.go +++ b/derive_test.go @@ -106,8 +106,7 @@ func TestDerive(t *testing.T) { getDerived := func() []derived { txn := db.ReadTxn() - iter, _ := outTable.All(txn) - objs := Collect(iter) + objs := Collect(outTable.All(txn)) // Log so we can trace the failed eventually calls t.Logf("derived: %+v", objs) return objs diff --git a/fuzz_test.go b/fuzz_test.go index a42e205..1ea3695 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -140,7 +140,7 @@ func (a *realActionLog) validateTable(txn statedb.ReadTxn, table statedb.Table[f // Since everything was deleted we can clear the log entries for this table now a.log[table.Name()] = nil - iter, _ := table.All(txn) + iter := table.All(txn) actual := map[string]struct{}{} for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() { actual[obj.id] = struct{}{} @@ -247,7 +247,7 @@ func deleteManyAction(ctx actionContext) { // nothing bad happens when the iterator is used while deleting. toDelete := ctx.table.NumObjects(ctx.txn) / 3 - iter, _ := ctx.table.All(ctx.txn) + iter := ctx.table.All(ctx.txn) n := 0 for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() { ctx.log.log("%s: DeleteMany %s (%d/%d)", ctx.table.Name(), obj.id, n+1, toDelete) @@ -267,7 +267,7 @@ func deleteManyAction(ctx actionContext) { } func allAction(ctx actionContext) { - iter, _ := ctx.table.All(ctx.txn) + iter := ctx.table.All(ctx.txn) ctx.log.log("%s: All => %d found", ctx.table.Name(), len(statedb.Collect(iter))) } @@ -319,13 +319,13 @@ func getAction(ctx actionContext) { func lowerboundAction(ctx actionContext) { id := mkID() - iter, _ := ctx.table.LowerBound(ctx.txn, idIndex.Query(id)) + iter, _ := ctx.table.LowerBoundWatch(ctx.txn, idIndex.Query(id)) ctx.log.log("%s: LowerBound(%s) => %d found", ctx.table.Name(), id, len(statedb.Collect(iter))) } func prefixAction(ctx actionContext) { id := mkID() - iter, _ := ctx.table.Prefix(ctx.txn, idIndex.Query(id)) + iter := ctx.table.Prefix(ctx.txn, idIndex.Query(id)) ctx.log.log("%s: Prefix(%s) => %d found", ctx.table.Name(), id, len(statedb.Collect(iter))) } @@ -417,7 +417,7 @@ func trackerWorker(i int, stop <-chan struct{}) { // Validate that the observed changes match with the database state at this // snapshot. state2 := maps.Clone(state) - iterAll, _ := tableFuzz1.LowerBound(txn, statedb.ByRevision[fuzzObj](0)) + iterAll := tableFuzz1.LowerBound(txn, statedb.ByRevision[fuzzObj](0)) for obj, rev, ok := iterAll.Next(); ok; obj, rev, ok = iterAll.Next() { change, found := state[obj.id] if !found { diff --git a/iterator.go b/iterator.go index 9895c2a..cdc1a4b 100644 --- a/iterator.go +++ b/iterator.go @@ -244,7 +244,7 @@ func (it *changeIterator[Obj]) Watch(txn ReadTxn) <-chan struct{} { return it.watch } - updateIter, watch := it.table.LowerBound(txn, ByRevision[Obj](it.revision+1)) + updateIter, watch := it.table.LowerBoundWatch(txn, ByRevision[Obj](it.revision+1)) deleteIter := it.dt.deleted(txn, it.revision+1) it.iter = NewDualIterator(deleteIter, updateIter) diff --git a/iterator_test.go b/iterator_test.go index 8344684..e2a504b 100644 --- a/iterator_test.go +++ b/iterator_test.go @@ -38,7 +38,7 @@ func TestFilter(t *testing.T) { table.Insert(txn, &testObject{ID: 5}) txn.Commit() - iter, _ := table.All(db.ReadTxn()) + iter := table.All(db.ReadTxn()) filtered := Collect( Map( Filter( diff --git a/reconciler/benchmark/main.go b/reconciler/benchmark/main.go index f70d8c6..48c901a 100644 --- a/reconciler/benchmark/main.go +++ b/reconciler/benchmark/main.go @@ -205,7 +205,7 @@ func main() { } // Check that all statuses are correctly set. - iter, _ := testObjects.All(db.ReadTxn()) + iter := testObjects.All(db.ReadTxn()) for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() { if obj.status.Kind != reconciler.StatusKindDone { log.Fatalf("Object with unexpected status: %#v", obj) diff --git a/reconciler/reconciler.go b/reconciler/reconciler.go index 1453b27..adf3cf6 100644 --- a/reconciler/reconciler.go +++ b/reconciler/reconciler.go @@ -111,7 +111,7 @@ func (r *reconciler[Obj]) reconcileLoop(ctx context.Context, health cell.Health) // prune performs the Prune operation to delete unexpected objects in the target system. func (r *reconciler[Obj]) prune(ctx context.Context, txn statedb.ReadTxn) error { - iter, _ := r.config.Table.All(txn) + iter := r.config.Table.All(txn) start := time.Now() err := r.config.Operations.Prune(ctx, txn, iter) if err != nil { @@ -145,7 +145,7 @@ outer: // Iterate over the objects in revision order, e.g. oldest modification first. // We look for objects that are older than [RefreshInterval] and mark them for // pending in order for them to be reconciled again. - iter, _ := r.config.Table.LowerBound(r.DB.ReadTxn(), statedb.ByRevision[Obj](lastRevision+1)) + iter := r.config.Table.LowerBound(r.DB.ReadTxn(), statedb.ByRevision[Obj](lastRevision+1)) indexer := r.config.Table.PrimaryIndexer() for obj, rev, ok := iter.Next(); ok; obj, rev, ok = iter.Next() { status := r.config.GetObjectStatus(obj) diff --git a/table.go b/table.go index cfded7a..8ded881 100644 --- a/table.go +++ b/table.go @@ -250,7 +250,12 @@ func (t *genTable[Obj]) GetWatch(txn ReadTxn, q Query[Obj]) (obj Obj, revision u return } -func (t *genTable[Obj]) LowerBound(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <-chan struct{}) { +func (t *genTable[Obj]) LowerBound(txn ReadTxn, q Query[Obj]) Iterator[Obj] { + iter, _ := t.LowerBoundWatch(txn, q) + return iter +} + +func (t *genTable[Obj]) LowerBoundWatch(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <-chan struct{}) { indexTxn := txn.getTxn().mustIndexReadTxn(t, t.indexPos(q.index)) // Since LowerBound query may be invalidated by changes in another branch // of the tree, we cannot just simply watch the node we seeked to. Instead @@ -260,13 +265,23 @@ func (t *genTable[Obj]) LowerBound(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <- return &iterator[Obj]{iter}, watch } -func (t *genTable[Obj]) Prefix(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <-chan struct{}) { +func (t *genTable[Obj]) Prefix(txn ReadTxn, q Query[Obj]) Iterator[Obj] { + iter, _ := t.PrefixWatch(txn, q) + return iter +} + +func (t *genTable[Obj]) PrefixWatch(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <-chan struct{}) { indexTxn := txn.getTxn().mustIndexReadTxn(t, t.indexPos(q.index)) iter, watch := indexTxn.Prefix(q.key) return &iterator[Obj]{iter}, watch } -func (t *genTable[Obj]) All(txn ReadTxn) (Iterator[Obj], <-chan struct{}) { +func (t *genTable[Obj]) All(txn ReadTxn) Iterator[Obj] { + iter, _ := t.AllWatch(txn) + return iter +} + +func (t *genTable[Obj]) AllWatch(txn ReadTxn) (Iterator[Obj], <-chan struct{}) { indexTxn := txn.getTxn().mustIndexReadTxn(t, PrimaryIndexPos) watch := indexTxn.RootWatch() return &iterator[Obj]{indexTxn.Iterator()}, watch @@ -323,7 +338,7 @@ func (t *genTable[Obj]) CompareAndDelete(txn WriteTxn, rev Revision, obj Obj) (o } func (t *genTable[Obj]) DeleteAll(txn WriteTxn) error { - iter, _ := t.All(txn) + iter := t.All(txn) itxn := txn.getTxn() for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() { _, _, err := itxn.delete(t, Revision(0), obj) @@ -354,8 +369,8 @@ func (t *genTable[Obj]) Changes(txn WriteTxn) (ChangeIterator[Obj], error) { } // Prepare the iterator - updateIter, watch := t.LowerBound(txn, ByRevision[Obj](0)) // observe all current objects - deleteIter := iter.dt.deleted(txn, iter.dt.getRevision()) // only observe new deletions + updateIter, watch := t.LowerBoundWatch(txn, ByRevision[Obj](0)) // observe all current objects + deleteIter := iter.dt.deleted(txn, iter.dt.getRevision()) // only observe new deletions iter.iter = NewDualIterator(deleteIter, updateIter) iter.watch = watch diff --git a/types.go b/types.go index 6cbbb7f..f0e808d 100644 --- a/types.go +++ b/types.go @@ -42,9 +42,12 @@ type Table[Obj any] interface { // increments in a write transaction on each Insert and Delete. Revision(ReadTxn) Revision - // All returns an iterator for all objects in the table and a watch + // All returns an iterator for all objects in the table. + All(ReadTxn) Iterator[Obj] + + // AllWatch returns an iterator for all objects in the table and a watch // channel that is closed when the table changes. - All(ReadTxn) (Iterator[Obj], <-chan struct{}) + AllWatch(ReadTxn) (Iterator[Obj], <-chan struct{}) // List returns an iterator for all objects matching the given query. List(ReadTxn, Query[Obj]) Iterator[Obj] @@ -62,13 +65,21 @@ type Table[Obj any] interface { GetWatch(ReadTxn, Query[Obj]) (obj Obj, rev Revision, watch <-chan struct{}, found bool) // LowerBound returns an iterator for objects that have a key + // greater or equal to the query. + LowerBound(ReadTxn, Query[Obj]) Iterator[Obj] + + // LowerBoundWatch returns an iterator for objects that have a key // greater or equal to the query. The returned watch channel is closed // when anything in the table changes as more fine-grained notifications // are not possible with a lower bound search. - LowerBound(ReadTxn, Query[Obj]) (iter Iterator[Obj], watch <-chan struct{}) + LowerBoundWatch(ReadTxn, Query[Obj]) (iter Iterator[Obj], watch <-chan struct{}) // Prefix searches the table by key prefix. - Prefix(ReadTxn, Query[Obj]) (iter Iterator[Obj], watch <-chan struct{}) + Prefix(ReadTxn, Query[Obj]) Iterator[Obj] + + // PrefixWatch searches the table by key prefix. Returns an iterator and a watch + // channel that closes when the query results have become stale. + PrefixWatch(ReadTxn, Query[Obj]) (iter Iterator[Obj], watch <-chan struct{}) // Changes returns an iterator for changes happening to the table. // This uses the revision index to iterate over the objects in the order