Skip to content

Commit

Permalink
Fix the old object returned from CompareAndSwap
Browse files Browse the repository at this point in the history
On revision mismatch 'object' was returned as the old
object instead of 'obj.data'. To avoid this happening
again change the return type to 'object' instead of 'any'
and add checks for the returned old object from CompareAndSwap
and CompareAndDelete.

Signed-off-by: Jussi Maki <[email protected]>
  • Loading branch information
joamaki committed Feb 20, 2024
1 parent 688a481 commit 10f9eb6
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 30 deletions.
12 changes: 8 additions & 4 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -635,9 +635,10 @@ func TestDB_CompareAndSwap_CompareAndDelete(t *testing.T) {
// Updating an object with matching revision number works
wtxn = db.WriteTxn(table)
obj.Tags = []string{"updated"} // NOTE: testObject stored by value so no explicit copy needed.
_, hadOld, err := table.CompareAndSwap(wtxn, rev1, obj)
oldObj, hadOld, err := table.CompareAndSwap(wtxn, rev1, obj)
require.NoError(t, err)
require.True(t, hadOld)
require.EqualValues(t, 1, oldObj.ID)
wtxn.Commit()

obj, _, ok = table.First(db.ReadTxn(), idIndex.Query(1))
Expand All @@ -648,9 +649,10 @@ func TestDB_CompareAndSwap_CompareAndDelete(t *testing.T) {
// Updating an object with mismatching revision number fails
wtxn = db.WriteTxn(table)
obj.Tags = []string{"mismatch"}
_, hadOld, err = table.CompareAndSwap(wtxn, rev1, obj)
oldObj, hadOld, err = table.CompareAndSwap(wtxn, rev1, obj)
require.ErrorIs(t, ErrRevisionNotEqual, err)
require.True(t, hadOld)
require.EqualValues(t, 1, oldObj.ID)
wtxn.Commit()

obj, _, ok = table.First(db.ReadTxn(), idIndex.Query(1))
Expand All @@ -661,9 +663,10 @@ func TestDB_CompareAndSwap_CompareAndDelete(t *testing.T) {
// Deleting an object with mismatching revision number fails
wtxn = db.WriteTxn(table)
obj.Tags = []string{"mismatch"}
_, hadOld, err = table.CompareAndDelete(wtxn, rev1, obj)
oldObj, hadOld, err = table.CompareAndDelete(wtxn, rev1, obj)
require.ErrorIs(t, ErrRevisionNotEqual, err)
require.True(t, hadOld)
require.EqualValues(t, 1, oldObj.ID)
wtxn.Commit()

obj, rev2, ok := table.First(db.ReadTxn(), idIndex.Query(1))
Expand All @@ -674,9 +677,10 @@ func TestDB_CompareAndSwap_CompareAndDelete(t *testing.T) {
// Deleting with matching revision number works
wtxn = db.WriteTxn(table)
obj.Tags = []string{"mismatch"}
_, hadOld, err = table.CompareAndDelete(wtxn, rev2, obj)
oldObj, hadOld, err = table.CompareAndDelete(wtxn, rev2, obj)
require.NoError(t, err)
require.True(t, hadOld)
require.EqualValues(t, 1, oldObj.ID)
wtxn.Commit()

_, _, ok = table.First(db.ReadTxn(), idIndex.Query(1))
Expand Down
32 changes: 16 additions & 16 deletions table.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,37 +237,37 @@ func (t *genTable[Obj]) Get(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <-chan st
}

func (t *genTable[Obj]) Insert(txn WriteTxn, obj Obj) (oldObj Obj, hadOld bool, err error) {
var data any
data, hadOld, err = txn.getTxn().Insert(t, Revision(0), obj)
if err == nil && hadOld {
oldObj = data.(Obj)
var old object
old, hadOld, err = txn.getTxn().Insert(t, Revision(0), obj)
if hadOld {
oldObj = old.data.(Obj)
}
return
}

func (t *genTable[Obj]) CompareAndSwap(txn WriteTxn, rev Revision, obj Obj) (oldObj Obj, hadOld bool, err error) {
var data any
data, hadOld, err = txn.getTxn().Insert(t, rev, obj)
if err == nil && hadOld {
oldObj = data.(Obj)
var old object
old, hadOld, err = txn.getTxn().Insert(t, rev, obj)
if hadOld {
oldObj = old.data.(Obj)
}
return
}

func (t *genTable[Obj]) Delete(txn WriteTxn, obj Obj) (oldObj Obj, hadOld bool, err error) {
var data any
data, hadOld, err = txn.getTxn().Delete(t, Revision(0), obj)
if err == nil && hadOld {
oldObj = data.(Obj)
var old object
old, hadOld, err = txn.getTxn().Delete(t, Revision(0), obj)
if hadOld {
oldObj = old.data.(Obj)
}
return
}

func (t *genTable[Obj]) CompareAndDelete(txn WriteTxn, rev Revision, obj Obj) (oldObj Obj, hadOld bool, err error) {
var data any
data, hadOld, err = txn.getTxn().Delete(t, rev, obj)
if err == nil && hadOld {
oldObj = data.(Obj)
var old object
old, hadOld, err = txn.getTxn().Delete(t, rev, obj)
if hadOld {
oldObj = old.data.(Obj)
}
return
}
Expand Down
20 changes: 10 additions & 10 deletions txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,16 @@ func (txn *txn) mustIndexWriteTxn(name TableName, index IndexName) indexWriteTxn
return indexTxn
}

func (txn *txn) Insert(meta TableMeta, guardRevision Revision, data any) (any, bool, error) {
func (txn *txn) Insert(meta TableMeta, guardRevision Revision, data any) (object, bool, error) {
if txn.rootReadTxn == nil {
return nil, false, ErrTransactionClosed
return object{}, false, ErrTransactionClosed
}

// Look up table and allocate a new revision.
tableName := meta.Name()
table, ok := txn.modifiedTables[tableName]
if !ok {
return nil, false, tableError(tableName, ErrTableNotLockedForWriting)
return object{}, false, tableError(tableName, ErrTableNotLockedForWriting)
}
oldRevision := table.revision
table.revision++
Expand All @@ -186,7 +186,7 @@ func (txn *txn) Insert(meta TableMeta, guardRevision Revision, data any) (any, b
// the insert.
idIndexTxn.txn.Delete(idKey)
table.revision = oldRevision
return nil, false, ErrObjectNotFound
return object{}, false, ErrObjectNotFound
}
if oldObj.revision != guardRevision {
// Revert the change. We're assuming here that it's rarer for CompareAndSwap() to
Expand Down Expand Up @@ -245,7 +245,7 @@ func (txn *txn) Insert(meta TableMeta, guardRevision Revision, data any) (any, b
})
}

return oldObj.data, oldExists, nil
return oldObj, oldExists, nil
}

func (txn *txn) hasDeleteTrackers(name TableName) bool {
Expand Down Expand Up @@ -277,16 +277,16 @@ func (txn *txn) addDeleteTracker(meta TableMeta, trackerName string, dt deleteTr

}

func (txn *txn) Delete(meta TableMeta, guardRevision Revision, data any) (any, bool, error) {
func (txn *txn) Delete(meta TableMeta, guardRevision Revision, data any) (object, bool, error) {
if txn.rootReadTxn == nil {
return nil, false, ErrTransactionClosed
return object{}, false, ErrTransactionClosed
}

// Look up table and allocate a new revision.
tableName := meta.Name()
table, ok := txn.modifiedTables[tableName]
if !ok {
return nil, false, tableError(tableName, ErrTableNotLockedForWriting)
return object{}, false, tableError(tableName, ErrTableNotLockedForWriting)
}
oldRevision := table.revision
table.revision++
Expand All @@ -299,7 +299,7 @@ func (txn *txn) Delete(meta TableMeta, guardRevision Revision, data any) (any, b
idIndexTree := txn.mustIndexWriteTxn(tableName, meta.primary().name)
obj, existed := idIndexTree.txn.Delete(idKey)
if !existed {
return nil, false, nil
return object{}, false, nil
}

// For CompareAndDelete() validate against guard revision and if there's a mismatch,
Expand Down Expand Up @@ -338,7 +338,7 @@ func (txn *txn) Delete(meta TableMeta, guardRevision Revision, data any) (any, b
txn.mustIndexWriteTxn(tableName, GraveyardRevisionIndex).txn.Insert(index.Uint64(revision), obj)
}

return obj.data, true, nil
return obj, true, nil
}

// encodeNonUniqueKey constructs the internal key to use with non-unique indexes.
Expand Down

0 comments on commit 10f9eb6

Please sign in to comment.