diff --git a/cli/client.go b/cli/client.go index 532712e8f8..475f83a80a 100644 --- a/cli/client.go +++ b/cli/client.go @@ -31,7 +31,7 @@ Execute queries, add schema types, obtain node info, etc.`, if err := setContextTransaction(cmd, txID); err != nil { return err } - return setContextStore(cmd) + return setContextDB(cmd) }, } cmd.PersistentFlags().Uint64Var(&txID, "tx", 0, "Transaction ID") diff --git a/cli/collection.go b/cli/collection.go index 23ef9194ae..5b682e5366 100644 --- a/cli/collection.go +++ b/cli/collection.go @@ -17,7 +17,6 @@ import ( "github.com/spf13/cobra" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" ) func MakeCollectionCommand() *cobra.Command { @@ -41,7 +40,7 @@ func MakeCollectionCommand() *cobra.Command { if err := setContextTransaction(cmd, txID); err != nil { return err } - if err := setContextStore(cmd); err != nil { + if err := setContextDB(cmd); err != nil { return err } store := mustGetContextStore(cmd) @@ -71,10 +70,6 @@ func MakeCollectionCommand() *cobra.Command { } col := cols[0] - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - col = col.WithTxn(tx) - } - ctx := context.WithValue(cmd.Context(), colContextKey, col) cmd.SetContext(ctx) return nil diff --git a/cli/index_create.go b/cli/index_create.go index bfe5ec64c2..0d724da15b 100644 --- a/cli/index_create.go +++ b/cli/index_create.go @@ -14,7 +14,6 @@ import ( "github.com/spf13/cobra" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" ) func MakeIndexCreateCommand() *cobra.Command { @@ -52,9 +51,6 @@ Example: create a named index for 'Users' collection on 'name' field: if err != nil { return err } - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - col = col.WithTxn(tx) - } desc, err = col.CreateIndex(cmd.Context(), desc) if err != nil { return err diff --git a/cli/index_drop.go b/cli/index_drop.go index 96f007268d..5dd069b5da 100644 --- a/cli/index_drop.go +++ b/cli/index_drop.go @@ -12,8 +12,6 @@ package cli import ( "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/datastore" ) func MakeIndexDropCommand() *cobra.Command { @@ -34,9 +32,6 @@ Example: drop the index 'UsersByName' for 'Users' collection: if err != nil { return err } - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - col = col.WithTxn(tx) - } return col.DropIndex(cmd.Context(), nameArg) }, } diff --git a/cli/index_list.go b/cli/index_list.go index bf1fd21251..481acb7d37 100644 --- a/cli/index_list.go +++ b/cli/index_list.go @@ -12,8 +12,6 @@ package cli import ( "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/datastore" ) func MakeIndexListCommand() *cobra.Command { @@ -38,9 +36,6 @@ Example: show all index for 'Users' collection: if err != nil { return err } - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - col = col.WithTxn(tx) - } indexes, err := col.GetIndexes(cmd.Context()) if err != nil { return err diff --git a/cli/schema_migration_down.go b/cli/schema_migration_down.go index 1d7622257c..a49f359694 100644 --- a/cli/schema_migration_down.go +++ b/cli/schema_migration_down.go @@ -17,8 +17,6 @@ import ( "github.com/sourcenetwork/immutable/enumerable" "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/datastore" ) func MakeSchemaMigrationDownCommand() *cobra.Command { @@ -67,12 +65,7 @@ Example: migrate from stdin if err := json.Unmarshal(srcData, &src); err != nil { return err } - lens := store.LensRegistry() - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - lens = lens.WithTxn(tx) - } - - out, err := lens.MigrateDown(cmd.Context(), enumerable.New(src), collectionID) + out, err := store.LensRegistry().MigrateDown(cmd.Context(), enumerable.New(src), collectionID) if err != nil { return err } diff --git a/cli/schema_migration_reload.go b/cli/schema_migration_reload.go index 4266b3ec3f..8ffb5542f1 100644 --- a/cli/schema_migration_reload.go +++ b/cli/schema_migration_reload.go @@ -12,8 +12,6 @@ package cli import ( "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/datastore" ) func MakeSchemaMigrationReloadCommand() *cobra.Command { @@ -23,12 +21,7 @@ func MakeSchemaMigrationReloadCommand() *cobra.Command { Long: `Reload the schema migrations within DefraDB`, RunE: func(cmd *cobra.Command, args []string) error { store := mustGetContextStore(cmd) - - lens := store.LensRegistry() - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - lens = lens.WithTxn(tx) - } - return lens.ReloadLenses(cmd.Context()) + return store.LensRegistry().ReloadLenses(cmd.Context()) }, } return cmd diff --git a/cli/schema_migration_up.go b/cli/schema_migration_up.go index 577b87d4c7..4473c45911 100644 --- a/cli/schema_migration_up.go +++ b/cli/schema_migration_up.go @@ -17,8 +17,6 @@ import ( "github.com/sourcenetwork/immutable/enumerable" "github.com/spf13/cobra" - - "github.com/sourcenetwork/defradb/datastore" ) func MakeSchemaMigrationUpCommand() *cobra.Command { @@ -67,12 +65,7 @@ Example: migrate from stdin if err := json.Unmarshal(srcData, &src); err != nil { return err } - lens := store.LensRegistry() - if tx, ok := cmd.Context().Value(txContextKey).(datastore.Txn); ok { - lens = lens.WithTxn(tx) - } - - out, err := lens.MigrateUp(cmd.Context(), enumerable.New(src), collectionID) + out, err := store.LensRegistry().MigrateUp(cmd.Context(), enumerable.New(src), collectionID) if err != nil { return err } diff --git a/cli/utils.go b/cli/utils.go index f923021fcf..1df10a3409 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -21,7 +21,7 @@ import ( "github.com/spf13/viper" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/db" "github.com/sourcenetwork/defradb/http" ) @@ -32,17 +32,8 @@ var ( cfgContextKey = contextKey("cfg") // rootDirContextKey is the context key for the root directory. rootDirContextKey = contextKey("rootDir") - // txContextKey is the context key for the datastore.Txn - // - // This will only be set if a transaction id is specified. - txContextKey = contextKey("tx") // dbContextKey is the context key for the client.DB dbContextKey = contextKey("db") - // storeContextKey is the context key for the client.Store - // - // If a transaction exists, all operations will be executed - // in the current transaction context. - storeContextKey = contextKey("store") // colContextKey is the context key for the client.Collection // // If a transaction exists, all operations will be executed @@ -61,7 +52,7 @@ func mustGetContextDB(cmd *cobra.Command) client.DB { // // If a store is not set in the current context this function panics. func mustGetContextStore(cmd *cobra.Command) client.Store { - return cmd.Context().Value(storeContextKey).(client.Store) + return cmd.Context().Value(dbContextKey).(client.Store) } // mustGetContextP2P returns the p2p implementation for the current command context. @@ -92,6 +83,18 @@ func tryGetContextCollection(cmd *cobra.Command) (client.Collection, bool) { return col, ok } +// setContextDB sets the db for the current command context. +func setContextDB(cmd *cobra.Command) error { + cfg := mustGetContextConfig(cmd) + db, err := http.NewClient(cfg.GetString("api.address")) + if err != nil { + return err + } + ctx := context.WithValue(cmd.Context(), dbContextKey, db) + cmd.SetContext(ctx) + return nil +} + // setContextConfig sets teh config for the current command context. func setContextConfig(cmd *cobra.Command) error { rootdir := mustGetContextRootDir(cmd) @@ -115,24 +118,7 @@ func setContextTransaction(cmd *cobra.Command, txId uint64) error { if err != nil { return err } - ctx := context.WithValue(cmd.Context(), txContextKey, tx) - cmd.SetContext(ctx) - return nil -} - -// setContextStore sets the store for the current command context. -func setContextStore(cmd *cobra.Command) error { - cfg := mustGetContextConfig(cmd) - db, err := http.NewClient(cfg.GetString("api.address")) - if err != nil { - return err - } - ctx := context.WithValue(cmd.Context(), dbContextKey, db) - if tx, ok := ctx.Value(txContextKey).(datastore.Txn); ok { - ctx = context.WithValue(ctx, storeContextKey, db.WithTxn(tx)) - } else { - ctx = context.WithValue(ctx, storeContextKey, db) - } + ctx := db.SetContextTxn(cmd.Context(), tx) cmd.SetContext(ctx) return nil } diff --git a/client/collection.go b/client/collection.go index aa219b3a74..bab61607a9 100644 --- a/client/collection.go +++ b/client/collection.go @@ -14,8 +14,6 @@ import ( "context" "github.com/sourcenetwork/immutable" - - "github.com/sourcenetwork/defradb/datastore" ) // Collection represents a defradb collection. @@ -192,10 +190,6 @@ type Collection interface { showDeleted bool, ) (*Document, error) - // WithTxn returns a new instance of the collection, with a transaction - // handle instead of a raw DB handle. - WithTxn(datastore.Txn) Collection - // GetAllDocIDs returns all the document IDs that exist in the collection. GetAllDocIDs(ctx context.Context, identity immutable.Option[string]) (<-chan DocIDResult, error) diff --git a/client/db.go b/client/db.go index a5d855f137..cedd63d492 100644 --- a/client/db.go +++ b/client/db.go @@ -42,9 +42,6 @@ type DB interface { // can safely operate on it concurrently. NewConcurrentTxn(context.Context, bool) (datastore.Txn, error) - // WithTxn returns a new [client.Store] that respects the given transaction. - WithTxn(datastore.Txn) Store - // Root returns the underlying root store, within which all data managed by DefraDB is held. Root() datastore.RootStore diff --git a/client/lens.go b/client/lens.go index 1a6b423991..3f5befc604 100644 --- a/client/lens.go +++ b/client/lens.go @@ -15,8 +15,6 @@ import ( "github.com/lens-vm/lens/host-go/config/model" "github.com/sourcenetwork/immutable/enumerable" - - "github.com/sourcenetwork/defradb/datastore" ) // LensConfig represents the configuration of a Lens migration in Defra. @@ -43,12 +41,6 @@ type LensConfig struct { // LensRegistry exposes several useful thread-safe migration related functions which may // be used to manage migrations. type LensRegistry interface { - // WithTxn returns a new LensRegistry scoped to the given transaction. - // - // WARNING: Currently this does not provide snapshot isolation, if other transactions are committed - // after this has been created, the results of those commits will be visible within this scope. - WithTxn(datastore.Txn) LensRegistry - // SetMigration caches the migration for the given collection ID. It does not persist the migration in long // term storage, for that one should call [Store.SetMigration(ctx, cfg)]. // diff --git a/db/backup.go b/db/backup.go index 2d3b824be1..17110bec05 100644 --- a/db/backup.go +++ b/db/backup.go @@ -92,7 +92,7 @@ func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath strin } // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to backup - err = col.WithTxn(txn).Create(ctx, acpIdentity.NoIdentity, doc) + err = col.Create(ctx, acpIdentity.NoIdentity, doc) if err != nil { return NewErrDocCreate(err) } @@ -104,7 +104,7 @@ func (db *db) basicImport(ctx context.Context, txn datastore.Txn, filepath strin return NewErrDocUpdate(err) } // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to backup - err = col.WithTxn(txn).Update(ctx, acpIdentity.NoIdentity, doc) + err = col.Update(ctx, acpIdentity.NoIdentity, doc) if err != nil { return NewErrDocUpdate(err) } @@ -191,9 +191,8 @@ func (db *db) basicExport(ctx context.Context, txn datastore.Txn, config *client if err != nil { return err } - colTxn := col.WithTxn(txn) // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to export - docIDsCh, err := colTxn.GetAllDocIDs(ctx, acpIdentity.NoIdentity) + docIDsCh, err := col.GetAllDocIDs(ctx, acpIdentity.NoIdentity) if err != nil { return err } @@ -210,7 +209,7 @@ func (db *db) basicExport(ctx context.Context, txn datastore.Txn, config *client } } // TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2430 - Add identity ability to export - doc, err := colTxn.Get(ctx, acpIdentity.NoIdentity, docResultWithID.ID, false) + doc, err := col.Get(ctx, acpIdentity.NoIdentity, docResultWithID.ID, false) if err != nil { return err } diff --git a/db/collection.go b/db/collection.go index d7364df3b2..1afa1c775a 100644 --- a/db/collection.go +++ b/db/collection.go @@ -46,18 +46,8 @@ var _ client.Collection = (*collection)(nil) // collection stores data records at Documents, which are gathered // together under a collection name. This is analogous to SQL Tables. type collection struct { - db *db - - // txn represents any externally provided [datastore.Txn] for which any - // operation on this [collection] instance should be scoped to. - // - // If this has no value, operations requiring a transaction should use an - // implicit internally managed transaction, which only lives for duration - // of the operation in question. - txn immutable.Option[datastore.Txn] - - def client.CollectionDefinition - + db *db + def client.CollectionDefinition indexes []CollectionIndex fetcherFactory func() fetcher.Fetcher } @@ -1240,11 +1230,10 @@ func (c *collection) GetAllDocIDs( ctx context.Context, identity immutable.Option[string], ) (<-chan client.DocIDResult, error) { - txn, err := c.getTxn(ctx, true) + ctx, txn, err := ensureContextTxn(ctx, c.db, true) if err != nil { return nil, err } - return c.getAllDocIDsChan(ctx, identity, txn) } @@ -1271,7 +1260,7 @@ func (c *collection) getAllDocIDsChan( log.ErrorContextE(ctx, errFailedtoCloseQueryReqAllIDs, err) } close(resCh) - c.discardImplicitTxn(ctx, txn) + txn.Discard(ctx) }() for res := range q.Next() { // check for Done on context first @@ -1351,18 +1340,6 @@ func (c *collection) Definition() client.CollectionDefinition { return c.def } -// WithTxn returns a new instance of the collection, with a transaction -// handle instead of a raw DB handle. -func (c *collection) WithTxn(txn datastore.Txn) client.Collection { - return &collection{ - db: c.db, - txn: immutable.Some(txn), - def: c.def, - indexes: c.indexes, - fetcherFactory: c.fetcherFactory, - } -} - // Create a new document. // Will verify the DocID/CID to ensure that the new document is correctly formatted. func (c *collection) Create( @@ -1370,18 +1347,18 @@ func (c *collection) Create( identity immutable.Option[string], doc *client.Document, ) error { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.create(ctx, identity, txn, doc) if err != nil { return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } // CreateMany creates a collection of documents at once. @@ -1391,11 +1368,11 @@ func (c *collection) CreateMany( identity immutable.Option[string], docs []*client.Document, ) error { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) for _, doc := range docs { err = c.create(ctx, identity, txn, doc) @@ -1403,7 +1380,7 @@ func (c *collection) CreateMany( return err } } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } func (c *collection) getDocIDAndPrimaryKeyFromDoc( @@ -1476,11 +1453,11 @@ func (c *collection) Update( identity immutable.Option[string], doc *client.Document, ) error { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(doc.ID()) exists, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) @@ -1499,7 +1476,7 @@ func (c *collection) Update( return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } // Contract: DB Exists check is already performed, and a doc with the given ID exists. @@ -1541,11 +1518,11 @@ func (c *collection) Save( identity immutable.Option[string], doc *client.Document, ) error { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) // Check if document already exists with primary DS key. primaryKey := c.getPrimaryKeyFromDocID(doc.ID()) @@ -1567,7 +1544,7 @@ func (c *collection) Save( return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } // save saves the document state. save MUST not be called outside the `c.create` @@ -1823,11 +1800,11 @@ func (c *collection) Delete( identity immutable.Option[string], docID client.DocID, ) (bool, error) { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return false, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) @@ -1835,7 +1812,7 @@ func (c *collection) Delete( if err != nil { return false, err } - return true, c.commitImplicitTxn(ctx, txn) + return true, txn.Commit(ctx) } // Exists checks if a given document exists with supplied DocID. @@ -1844,18 +1821,18 @@ func (c *collection) Exists( identity immutable.Option[string], docID client.DocID, ) (bool, error) { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return false, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) exists, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) if err != nil && !errors.Is(err, ds.ErrNotFound) { return false, err } - return exists && !isDeleted, c.commitImplicitTxn(ctx, txn) + return exists && !isDeleted, txn.Commit(ctx) } // check if a document exists with the given primary key @@ -1916,35 +1893,6 @@ func (c *collection) saveCompositeToMerkleCRDT( return merkleCRDT.Save(ctx, links) } -// getTxn gets or creates a new transaction from the underlying db. -// If the collection already has a txn, return the existing one. -// Otherwise, create a new implicit transaction. -func (c *collection) getTxn(ctx context.Context, readonly bool) (datastore.Txn, error) { - if c.txn.HasValue() { - return c.txn.Value(), nil - } - return c.db.NewTxn(ctx, readonly) -} - -// discardImplicitTxn is a proxy function used by the collection to execute the Discard() -// transaction function only if its an implicit transaction. -// -// Implicit transactions are transactions that are created *during* an operation execution as a side effect. -// -// Explicit transactions are provided to the collection object via the "WithTxn(...)" function. -func (c *collection) discardImplicitTxn(ctx context.Context, txn datastore.Txn) { - if !c.txn.HasValue() { - txn.Discard(ctx) - } -} - -func (c *collection) commitImplicitTxn(ctx context.Context, txn datastore.Txn) error { - if !c.txn.HasValue() { - return txn.Commit(ctx) - } - return nil -} - func (c *collection) getPrimaryKeyFromDocID(docID client.DocID) core.PrimaryDataStoreKey { return core.PrimaryDataStoreKey{ CollectionRootID: c.Description().RootID, diff --git a/db/collection_delete.go b/db/collection_delete.go index 984cd27a21..8d5bf3f2bb 100644 --- a/db/collection_delete.go +++ b/db/collection_delete.go @@ -54,12 +54,11 @@ func (c *collection) DeleteWithDocID( identity immutable.Option[string], docID client.DocID, ) (*client.DeleteResult, error) { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) dsKey := c.getPrimaryKeyFromDocID(docID) res, err := c.deleteWithKey(ctx, identity, txn, dsKey) @@ -67,7 +66,7 @@ func (c *collection) DeleteWithDocID( return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } // DeleteWithDocIDs is the same as DeleteWithDocID but accepts multiple DocIDs as a slice. @@ -76,19 +75,18 @@ func (c *collection) DeleteWithDocIDs( identity immutable.Option[string], docIDs []client.DocID, ) (*client.DeleteResult, error) { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) res, err := c.deleteWithIDs(ctx, identity, txn, docIDs, client.Deleted) if err != nil { return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } // DeleteWithFilter deletes using a filter to target documents for delete. @@ -97,19 +95,18 @@ func (c *collection) DeleteWithFilter( identity immutable.Option[string], filter any, ) (*client.DeleteResult, error) { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) res, err := c.deleteWithFilter(ctx, identity, txn, filter, client.Deleted) if err != nil { return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } func (c *collection) deleteWithKey( diff --git a/db/collection_get.go b/db/collection_get.go index 16d5bd4711..8ae0dcae75 100644 --- a/db/collection_get.go +++ b/db/collection_get.go @@ -29,11 +29,11 @@ func (c *collection) Get( showDeleted bool, ) (*client.Document, error) { // create txn - txn, err := c.getTxn(ctx, true) + ctx, txn, err := ensureContextTxn(ctx, c.db, true) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) primaryKey := c.getPrimaryKeyFromDocID(docID) found, isDeleted, err := c.exists(ctx, identity, txn, primaryKey) @@ -53,7 +53,7 @@ func (c *collection) Get( return nil, client.ErrDocumentNotFoundOrNotAuthorized } - return doc, c.commitImplicitTxn(ctx, txn) + return doc, txn.Commit(ctx) } func (c *collection) get( diff --git a/db/collection_index.go b/db/collection_index.go index 1a7af8cc25..3e33c94709 100644 --- a/db/collection_index.go +++ b/db/collection_index.go @@ -41,7 +41,7 @@ func (db *db) createCollectionIndex( if err != nil { return client.IndexDescription{}, NewErrCanNotReadCollection(collectionName, err) } - col = col.WithTxn(txn) + ctx = SetContextTxn(ctx, txn) return col.CreateIndex(ctx, desc) } @@ -54,7 +54,7 @@ func (db *db) dropCollectionIndex( if err != nil { return NewErrCanNotReadCollection(collectionName, err) } - col = col.WithTxn(txn) + ctx = SetContextTxn(ctx, txn) return col.DropIndex(ctx, indexName) } @@ -112,26 +112,26 @@ func (db *db) fetchCollectionIndexDescriptions( } func (c *collection) CreateDocIndex(ctx context.Context, doc *client.Document) error { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.indexNewDoc(ctx, txn, doc) if err != nil { return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } func (c *collection) UpdateDocIndex(ctx context.Context, oldDoc, newDoc *client.Document) error { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.deleteIndexedDoc(ctx, txn, oldDoc) if err != nil { @@ -142,22 +142,22 @@ func (c *collection) UpdateDocIndex(ctx context.Context, oldDoc, newDoc *client. return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } func (c *collection) DeleteDocIndex(ctx context.Context, doc *client.Document) error { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.deleteIndexedDoc(ctx, txn, doc) if err != nil { return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } func (c *collection) indexNewDoc(ctx context.Context, txn datastore.Txn, doc *client.Document) error { @@ -242,17 +242,17 @@ func (c *collection) CreateIndex( ctx context.Context, desc client.IndexDescription, ) (client.IndexDescription, error) { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return client.IndexDescription{}, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) index, err := c.createIndex(ctx, txn, desc) if err != nil { return client.IndexDescription{}, err } - return index.Description(), c.commitImplicitTxn(ctx, txn) + return index.Description(), txn.Commit(ctx) } func (c *collection) createIndex( @@ -398,17 +398,17 @@ func (c *collection) indexExistingDocs( // // All index artifacts for existing documents related the index will be removed. func (c *collection) DropIndex(ctx context.Context, indexName string) error { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.dropIndex(ctx, txn, indexName) if err != nil { return err } - return c.commitImplicitTxn(ctx, txn) + return txn.Commit(ctx) } func (c *collection) dropIndex(ctx context.Context, txn datastore.Txn, indexName string) error { @@ -486,11 +486,11 @@ func (c *collection) loadIndexes(ctx context.Context, txn datastore.Txn) error { // GetIndexes returns all indexes for the collection. func (c *collection) GetIndexes(ctx context.Context) ([]client.IndexDescription, error) { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) err = c.loadIndexes(ctx, txn) if err != nil { diff --git a/db/collection_update.go b/db/collection_update.go index dcc3ba6cba..e9ab2e7fa1 100644 --- a/db/collection_update.go +++ b/db/collection_update.go @@ -57,16 +57,17 @@ func (c *collection) UpdateWithFilter( filter any, updater string, ) (*client.UpdateResult, error) { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) + res, err := c.updateWithFilter(ctx, identity, txn, filter, updater) if err != nil { return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } // UpdateWithDocID updates using a DocID to target a single document for update. @@ -78,17 +79,18 @@ func (c *collection) UpdateWithDocID( docID client.DocID, updater string, ) (*client.UpdateResult, error) { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) + res, err := c.updateWithDocID(ctx, identity, txn, docID, updater) if err != nil { return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } // UpdateWithDocIDs is the same as UpdateWithDocID but accepts multiple DocIDs as a slice. @@ -100,17 +102,18 @@ func (c *collection) UpdateWithDocIDs( docIDs []client.DocID, updater string, ) (*client.UpdateResult, error) { - txn, err := c.getTxn(ctx, false) + ctx, txn, err := ensureContextTxn(ctx, c.db, false) if err != nil { return nil, err } - defer c.discardImplicitTxn(ctx, txn) + defer txn.Discard(ctx) + res, err := c.updateWithIDs(ctx, identity, txn, docIDs, updater) if err != nil { return nil, err } - return res, c.commitImplicitTxn(ctx, txn) + return res, txn.Commit(ctx) } func (c *collection) updateWithDocID( @@ -333,7 +336,6 @@ func (c *collection) patchPrimaryDoc( if err != nil { return err } - primaryCol = primaryCol.WithTxn(txn) primarySchema := primaryCol.Schema() primaryField, ok := primaryCol.Description().GetFieldByRelation( @@ -439,7 +441,7 @@ func (c *collection) makeSelectionPlan( ctx, identity, c.db.acp, - c.db.WithTxn(txn), + c.db, txn, ) diff --git a/db/context.go b/db/context.go new file mode 100644 index 0000000000..d39472ea5a --- /dev/null +++ b/db/context.go @@ -0,0 +1,68 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package db + +import ( + "context" + + "github.com/sourcenetwork/defradb/datastore" +) + +// txnContextKey is the key type for transaction context values. +type txnContextKey struct{} + +// explicitTxn is a transaction that is managed outside of a db operation. +type explicitTxn struct { + datastore.Txn +} + +func (t *explicitTxn) Commit(ctx context.Context) error { + return nil // do nothing +} + +func (t *explicitTxn) Discard(ctx context.Context) { + // do nothing +} + +// transactionDB is a db that can create transactions. +type transactionDB interface { + NewTxn(context.Context, bool) (datastore.Txn, error) +} + +// ensureContextTxn ensures that the returned context has a transaction. +// +// If a transactions exists on the context it will be made explicit, +// otherwise a new implicit transaction will be created. +func ensureContextTxn(ctx context.Context, db transactionDB, readOnly bool) (context.Context, datastore.Txn, error) { + txn, ok := TryGetContextTxn(ctx) + if ok { + return SetContextTxn(ctx, &explicitTxn{txn}), &explicitTxn{txn}, nil + } + txn, err := db.NewTxn(ctx, readOnly) + if err != nil { + return nil, txn, err + } + return SetContextTxn(ctx, txn), txn, nil +} + +// TryGetContextTxn returns a transaction and a bool indicating if the +// txn was retrieved from the given context. +func TryGetContextTxn(ctx context.Context) (datastore.Txn, bool) { + txn, ok := ctx.Value(txnContextKey{}).(datastore.Txn) + return txn, ok +} + +// SetContextTxn returns a new context with the txn value set. +// +// This will overwrite any previously set transaction value. +func SetContextTxn(ctx context.Context, txn datastore.Txn) context.Context { + return context.WithValue(ctx, txnContextKey{}, txn) +} diff --git a/db/context_test.go b/db/context_test.go new file mode 100644 index 0000000000..c8b1a322e5 --- /dev/null +++ b/db/context_test.go @@ -0,0 +1,57 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package db + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEnsureContextTxnExplicit(t *testing.T) { + ctx := context.Background() + + db, err := newMemoryDB(ctx) + require.NoError(t, err) + + txn, err := db.NewTxn(ctx, true) + require.NoError(t, err) + + // set an explicit transaction + ctx = SetContextTxn(ctx, txn) + + ctx, txn, err = ensureContextTxn(ctx, db, true) + require.NoError(t, err) + + _, ok := txn.(*explicitTxn) + assert.True(t, ok) + + _, ok = ctx.Value(txnContextKey{}).(*explicitTxn) + assert.True(t, ok) +} + +func TestEnsureContextTxnImplicit(t *testing.T) { + ctx := context.Background() + + db, err := newMemoryDB(ctx) + require.NoError(t, err) + + ctx, txn, err := ensureContextTxn(ctx, db, true) + require.NoError(t, err) + + _, ok := txn.(*explicitTxn) + assert.False(t, ok) + + _, ok = ctx.Value(txnContextKey{}).(*explicitTxn) + assert.False(t, ok) +} diff --git a/db/db.go b/db/db.go index 239b26f9a7..e7a6fa8d09 100644 --- a/db/db.go +++ b/db/db.go @@ -89,7 +89,7 @@ func newDB( ctx context.Context, rootstore datastore.RootStore, options ...Option, -) (*implicitTxnDB, error) { +) (*db, error) { multistore := datastore.MultiStoreFrom(rootstore) parser, err := graphql.NewParser() @@ -119,7 +119,7 @@ func newDB( return nil, err } - return &implicitTxnDB{db}, nil + return db, nil } // NewTxn creates a new transaction. @@ -134,15 +134,6 @@ func (db *db) NewConcurrentTxn(ctx context.Context, readonly bool) (datastore.Tx return datastore.NewConcurrentTxnFrom(ctx, db.rootstore, txnId, readonly) } -// WithTxn returns a new [client.Store] that respects the given transaction. -func (db *db) WithTxn(txn datastore.Txn) client.Store { - return &explicitTxnDB{ - db: db, - txn: txn, - lensRegistry: db.lensRegistry.WithTxn(txn), - } -} - // Root returns the root datastore. func (db *db) Root() datastore.RootStore { return db.rootstore diff --git a/db/db_test.go b/db/db_test.go index 237a1f21ed..118adb285b 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -19,7 +19,7 @@ import ( badgerds "github.com/sourcenetwork/defradb/datastore/badger/v4" ) -func newMemoryDB(ctx context.Context) (*implicitTxnDB, error) { +func newMemoryDB(ctx context.Context) (*db, error) { opts := badgerds.Options{Options: badger.DefaultOptions("").WithInMemory(true)} rootstore, err := badgerds.NewDatastore("", &opts) if err != nil { diff --git a/db/index_test.go b/db/index_test.go index 44c2e45f52..aeda2bdd6d 100644 --- a/db/index_test.go +++ b/db/index_test.go @@ -53,7 +53,7 @@ const ( type indexTestFixture struct { ctx context.Context - db *implicitTxnDB + db *db txn datastore.Txn users client.Collection t *testing.T @@ -784,7 +784,8 @@ func TestCollectionGetIndexes_ShouldCloseQueryIterator(t *testing.T) { mockedTxn.MockSystemstore.EXPECT().Query(mock.Anything, mock.Anything). Return(queryResults, nil) - _, err := f.users.WithTxn(mockedTxn).GetIndexes(f.ctx) + ctx := SetContextTxn(f.ctx, mockedTxn) + _, err := f.users.GetIndexes(ctx) assert.NoError(t, err) } @@ -840,7 +841,8 @@ func TestCollectionGetIndexes_IfSystemStoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Systemstore().Unset() mockedTxn.EXPECT().Systemstore().Return(mockedTxn.MockSystemstore).Maybe() - _, err := f.users.WithTxn(mockedTxn).GetIndexes(f.ctx) + ctx := SetContextTxn(f.ctx, mockedTxn) + _, err := f.users.GetIndexes(ctx) require.ErrorIs(t, err, testCase.ExpectedError) } } @@ -902,7 +904,8 @@ func TestCollectionGetIndexes_IfStoredIndexWithUnsupportedType_ReturnError(t *te mockedTxn.MockSystemstore.EXPECT().Query(mock.Anything, mock.Anything). Return(mocks.NewQueryResultsWithValues(t, indexDescData), nil) - _, err = collection.WithTxn(mockedTxn).GetIndexes(f.ctx) + ctx := SetContextTxn(f.ctx, mockedTxn) + _, err = collection.GetIndexes(ctx) require.ErrorIs(t, err, NewErrUnsupportedIndexFieldType(unsupportedKind)) } @@ -1093,17 +1096,18 @@ func TestDropIndex_IfFailsToDeleteFromStorage_ReturnError(t *testing.T) { mockedTxn.MockDatastore.EXPECT().Query(mock.Anything, mock.Anything).Maybe(). Return(mocks.NewQueryResultsWithValues(t), nil) - err := f.users.WithTxn(mockedTxn).DropIndex(f.ctx, testUsersColIndexName) + ctx := SetContextTxn(f.ctx, mockedTxn) + err := f.users.DropIndex(ctx, testUsersColIndexName) require.ErrorIs(t, err, testErr) } func TestDropIndex_ShouldUpdateCollectionsDescription(t *testing.T) { f := newIndexTestFixture(t) defer f.db.Close() - col := f.users.WithTxn(f.txn) - _, err := col.CreateIndex(f.ctx, getUsersIndexDescOnName()) + ctx := SetContextTxn(f.ctx, f.txn) + _, err := f.users.CreateIndex(ctx, getUsersIndexDescOnName()) require.NoError(t, err) - indOnAge, err := col.CreateIndex(f.ctx, getUsersIndexDescOnAge()) + indOnAge, err := f.users.CreateIndex(ctx, getUsersIndexDescOnAge()) require.NoError(t, err) f.commitTxn() @@ -1144,7 +1148,8 @@ func TestDropIndex_IfSystemStoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Systemstore().Unset() mockedTxn.EXPECT().Systemstore().Return(mockedTxn.MockSystemstore).Maybe() - err := f.users.WithTxn(mockedTxn).DropIndex(f.ctx, testUsersColIndexName) + ctx := SetContextTxn(f.ctx, mockedTxn) + err := f.users.DropIndex(ctx, testUsersColIndexName) require.ErrorIs(t, err, testErr) } diff --git a/db/indexed_docs_test.go b/db/indexed_docs_test.go index c11eb2617f..70604fdc1f 100644 --- a/db/indexed_docs_test.go +++ b/db/indexed_docs_test.go @@ -322,7 +322,8 @@ func TestNonUnique_IfFailsToStoredIndexedDoc_Error(t *testing.T) { dataStoreOn.Put(mock.Anything, key.ToDS(), mock.Anything).Return(errors.New("error")) dataStoreOn.Put(mock.Anything, mock.Anything, mock.Anything).Return(nil) - err := f.users.WithTxn(mockTxn).Create(f.ctx, acpIdentity.NoIdentity, doc) + ctx := SetContextTxn(f.ctx, mockTxn) + err := f.users.Create(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(f.t, err, NewErrFailedToStoreIndexedField("name", nil)) } @@ -360,7 +361,8 @@ func TestNonUnique_IfSystemStorageHasInvalidIndexDescription_Error(t *testing.T) systemStoreOn.Query(mock.Anything, mock.Anything). Return(mocks.NewQueryResultsWithValues(t, []byte("invalid")), nil) - err := f.users.WithTxn(mockTxn).Create(f.ctx, acpIdentity.NoIdentity, doc) + ctx := SetContextTxn(f.ctx, mockTxn) + err := f.users.Create(ctx, acpIdentity.NoIdentity, doc) assert.ErrorIs(t, err, datastore.NewErrInvalidStoredValue(nil)) } @@ -378,7 +380,8 @@ func TestNonUnique_IfSystemStorageFailsToReadIndexDesc_Error(t *testing.T) { systemStoreOn.Query(mock.Anything, mock.Anything). Return(nil, testErr) - err := f.users.WithTxn(mockTxn).Create(f.ctx, acpIdentity.NoIdentity, doc) + ctx := SetContextTxn(f.ctx, mockTxn) + err := f.users.Create(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } @@ -806,7 +809,8 @@ func TestNonUniqueUpdate_IfFailsToReadIndexDescription_ReturnError(t *testing.T) usersCol.(*collection).fetcherFactory = func() fetcher.Fetcher { return fetcherMocks.NewStubbedFetcher(t) } - err = usersCol.WithTxn(mockedTxn).Update(f.ctx, acpIdentity.NoIdentity, doc) + ctx := SetContextTxn(f.ctx, mockedTxn) + err = usersCol.Update(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } @@ -1048,7 +1052,8 @@ func TestNonUniqueUpdate_IfDatastoreFails_ReturnError(t *testing.T) { mockedTxn.EXPECT().Datastore().Unset() mockedTxn.EXPECT().Datastore().Return(mockedTxn.MockDatastore).Maybe() - err = f.users.WithTxn(mockedTxn).Update(f.ctx, acpIdentity.NoIdentity, doc) + ctx := SetContextTxn(f.ctx, mockedTxn) + err = f.users.Update(ctx, acpIdentity.NoIdentity, doc) require.ErrorIs(t, err, testErr) } } diff --git a/db/request.go b/db/request.go index 2905ee4de2..69b300f482 100644 --- a/db/request.go +++ b/db/request.go @@ -59,7 +59,7 @@ func (db *db) execRequest( ctx, identity, db.acp, - db.WithTxn(txn), + db, txn, ) diff --git a/db/store.go b/db/store.go new file mode 100644 index 0000000000..aff11f851d --- /dev/null +++ b/db/store.go @@ -0,0 +1,275 @@ +// Copyright 2024 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package db + +import ( + "context" + + "github.com/lens-vm/lens/host-go/config/model" + + "github.com/sourcenetwork/immutable" + + "github.com/sourcenetwork/defradb/client" +) + +// ExecRequest executes a request against the database. +func (db *db) ExecRequest( + ctx context.Context, + identity immutable.Option[string], + request string, +) *client.RequestResult { + ctx, txn, err := ensureContextTxn(ctx, db, false) + if err != nil { + res := &client.RequestResult{} + res.GQL.Errors = []error{err} + return res + } + defer txn.Discard(ctx) + + res := db.execRequest(ctx, identity, request, txn) + if len(res.GQL.Errors) > 0 { + return res + } + + if err := txn.Commit(ctx); err != nil { + res.GQL.Errors = []error{err} + return res + } + + return res +} + +// GetCollectionByName returns an existing collection within the database. +func (db *db) GetCollectionByName(ctx context.Context, name string) (client.Collection, error) { + ctx, txn, err := ensureContextTxn(ctx, db, true) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + return db.getCollectionByName(ctx, txn, name) +} + +// GetCollections gets all the currently defined collections. +func (db *db) GetCollections( + ctx context.Context, + options client.CollectionFetchOptions, +) ([]client.Collection, error) { + ctx, txn, err := ensureContextTxn(ctx, db, true) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + return db.getCollections(ctx, txn, options) +} + +// GetSchemaByVersionID returns the schema description for the schema version of the +// ID provided. +// +// Will return an error if it is not found. +func (db *db) GetSchemaByVersionID(ctx context.Context, versionID string) (client.SchemaDescription, error) { + ctx, txn, err := ensureContextTxn(ctx, db, true) + if err != nil { + return client.SchemaDescription{}, err + } + defer txn.Discard(ctx) + + return db.getSchemaByVersionID(ctx, txn, versionID) +} + +// GetSchemas returns all schema versions that currently exist within +// this [Store]. +func (db *db) GetSchemas( + ctx context.Context, + options client.SchemaFetchOptions, +) ([]client.SchemaDescription, error) { + ctx, txn, err := ensureContextTxn(ctx, db, true) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + return db.getSchemas(ctx, txn, options) +} + +// GetAllIndexes gets all the indexes in the database. +func (db *db) GetAllIndexes( + ctx context.Context, +) (map[client.CollectionName][]client.IndexDescription, error) { + ctx, txn, err := ensureContextTxn(ctx, db, true) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + return db.getAllIndexDescriptions(ctx, txn) +} + +// AddSchema takes the provided GQL schema in SDL format, and applies it to the database, +// creating the necessary collections, request types, etc. +// +// All schema types provided must not exist prior to calling this, and they may not reference existing +// types previously defined. +func (db *db) AddSchema(ctx context.Context, schemaString string) ([]client.CollectionDescription, error) { + ctx, txn, err := ensureContextTxn(ctx, db, false) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + cols, err := db.addSchema(ctx, txn, schemaString) + if err != nil { + return nil, err + } + + if err := txn.Commit(ctx); err != nil { + return nil, err + } + return cols, nil +} + +// PatchSchema takes the given JSON patch string and applies it to the set of CollectionDescriptions +// present in the database. +// +// It will also update the GQL types used by the query system. It will error and not apply any of the +// requested, valid updates should the net result of the patch result in an invalid state. The +// individual operations defined in the patch do not need to result in a valid state, only the net result +// of the full patch. +// +// The collections (including the schema version ID) will only be updated if any changes have actually +// been made, if the net result of the patch matches the current persisted description then no changes +// will be applied. +func (db *db) PatchSchema( + ctx context.Context, + patchString string, + migration immutable.Option[model.Lens], + setAsDefaultVersion bool, +) error { + ctx, txn, err := ensureContextTxn(ctx, db, false) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = db.patchSchema(ctx, txn, patchString, migration, setAsDefaultVersion) + if err != nil { + return err + } + + return txn.Commit(ctx) +} + +func (db *db) PatchCollection( + ctx context.Context, + patchString string, +) error { + ctx, txn, err := ensureContextTxn(ctx, db, false) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = db.patchCollection(ctx, txn, patchString) + if err != nil { + return err + } + + return txn.Commit(ctx) +} + +func (db *db) SetActiveSchemaVersion(ctx context.Context, schemaVersionID string) error { + ctx, txn, err := ensureContextTxn(ctx, db, false) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = db.setActiveSchemaVersion(ctx, txn, schemaVersionID) + if err != nil { + return err + } + + return txn.Commit(ctx) +} + +func (db *db) SetMigration(ctx context.Context, cfg client.LensConfig) error { + ctx, txn, err := ensureContextTxn(ctx, db, false) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = db.setMigration(ctx, txn, cfg) + if err != nil { + return err + } + + return txn.Commit(ctx) +} + +func (db *db) AddView( + ctx context.Context, + query string, + sdl string, + transform immutable.Option[model.Lens], +) ([]client.CollectionDefinition, error) { + ctx, txn, err := ensureContextTxn(ctx, db, false) + if err != nil { + return nil, err + } + defer txn.Discard(ctx) + + defs, err := db.addView(ctx, txn, query, sdl, transform) + if err != nil { + return nil, err + } + + err = txn.Commit(ctx) + if err != nil { + return nil, err + } + + return defs, nil +} + +// BasicImport imports a json dataset. +// filepath must be accessible to the node. +func (db *db) BasicImport(ctx context.Context, filepath string) error { + ctx, txn, err := ensureContextTxn(ctx, db, false) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = db.basicImport(ctx, txn, filepath) + if err != nil { + return err + } + + return txn.Commit(ctx) +} + +// BasicExport exports the current data or subset of data to file in json format. +func (db *db) BasicExport(ctx context.Context, config *client.BackupConfig) error { + ctx, txn, err := ensureContextTxn(ctx, db, true) + if err != nil { + return err + } + defer txn.Discard(ctx) + + err = db.basicExport(ctx, txn, config) + if err != nil { + return err + } + + return txn.Commit(ctx) +} diff --git a/db/subscriptions.go b/db/subscriptions.go index f6f187c54f..e649769c18 100644 --- a/db/subscriptions.go +++ b/db/subscriptions.go @@ -62,8 +62,8 @@ func (db *db) handleSubscription( continue } + ctx := SetContextTxn(ctx, txn) db.handleEvent(ctx, identity, txn, pub, evt, r) - txn.Discard(ctx) } } @@ -80,7 +80,7 @@ func (db *db) handleEvent( ctx, identity, db.acp, - db.WithTxn(txn), + db, txn, ) diff --git a/db/txn_db.go b/db/txn_db.go deleted file mode 100644 index e77176b433..0000000000 --- a/db/txn_db.go +++ /dev/null @@ -1,422 +0,0 @@ -// Copyright 2023 Democratized Data Foundation -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package db - -import ( - "context" - - "github.com/lens-vm/lens/host-go/config/model" - - "github.com/sourcenetwork/immutable" - - "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" -) - -var _ client.DB = (*implicitTxnDB)(nil) -var _ client.DB = (*explicitTxnDB)(nil) -var _ client.Store = (*implicitTxnDB)(nil) -var _ client.Store = (*explicitTxnDB)(nil) - -type implicitTxnDB struct { - *db -} - -type explicitTxnDB struct { - *db - txn datastore.Txn - lensRegistry client.LensRegistry -} - -// ExecRequest executes a request against the database. -func (db *implicitTxnDB) ExecRequest( - ctx context.Context, - identity immutable.Option[string], - request string, -) *client.RequestResult { - txn, err := db.NewTxn(ctx, false) - if err != nil { - res := &client.RequestResult{} - res.GQL.Errors = []error{err} - return res - } - defer txn.Discard(ctx) - - res := db.execRequest(ctx, identity, request, txn) - if len(res.GQL.Errors) > 0 { - return res - } - - if err := txn.Commit(ctx); err != nil { - res.GQL.Errors = []error{err} - return res - } - - return res -} - -// ExecRequest executes a transaction request against the database. -func (db *explicitTxnDB) ExecRequest( - ctx context.Context, - identity immutable.Option[string], - request string, -) *client.RequestResult { - return db.execRequest(ctx, identity, request, db.txn) -} - -// GetCollectionByName returns an existing collection within the database. -func (db *implicitTxnDB) GetCollectionByName(ctx context.Context, name string) (client.Collection, error) { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - return db.getCollectionByName(ctx, txn, name) -} - -// GetCollectionByName returns an existing collection within the database. -func (db *explicitTxnDB) GetCollectionByName(ctx context.Context, name string) (client.Collection, error) { - col, err := db.getCollectionByName(ctx, db.txn, name) - if err != nil { - return nil, err - } - - return col.WithTxn(db.txn), nil -} - -// GetCollections gets all the currently defined collections. -func (db *implicitTxnDB) GetCollections( - ctx context.Context, - options client.CollectionFetchOptions, -) ([]client.Collection, error) { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - return db.getCollections(ctx, txn, options) -} - -// GetCollections gets all the currently defined collections. -func (db *explicitTxnDB) GetCollections( - ctx context.Context, - options client.CollectionFetchOptions, -) ([]client.Collection, error) { - cols, err := db.getCollections(ctx, db.txn, options) - if err != nil { - return nil, err - } - - for i := range cols { - cols[i] = cols[i].WithTxn(db.txn) - } - - return cols, nil -} - -// GetSchemaByVersionID returns the schema description for the schema version of the -// ID provided. -// -// Will return an error if it is not found. -func (db *implicitTxnDB) GetSchemaByVersionID(ctx context.Context, versionID string) (client.SchemaDescription, error) { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return client.SchemaDescription{}, err - } - defer txn.Discard(ctx) - - return db.getSchemaByVersionID(ctx, txn, versionID) -} - -// GetSchemaByVersionID returns the schema description for the schema version of the -// ID provided. -// -// Will return an error if it is not found. -func (db *explicitTxnDB) GetSchemaByVersionID(ctx context.Context, versionID string) (client.SchemaDescription, error) { - return db.getSchemaByVersionID(ctx, db.txn, versionID) -} - -// GetSchemas returns all schema versions that currently exist within -// this [Store]. -func (db *implicitTxnDB) GetSchemas( - ctx context.Context, - options client.SchemaFetchOptions, -) ([]client.SchemaDescription, error) { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - return db.getSchemas(ctx, txn, options) -} - -// GetSchemas returns all schema versions that currently exist within -// this [Store]. -func (db *explicitTxnDB) GetSchemas( - ctx context.Context, - options client.SchemaFetchOptions, -) ([]client.SchemaDescription, error) { - return db.getSchemas(ctx, db.txn, options) -} - -// GetAllIndexes gets all the indexes in the database. -func (db *implicitTxnDB) GetAllIndexes( - ctx context.Context, -) (map[client.CollectionName][]client.IndexDescription, error) { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - return db.getAllIndexDescriptions(ctx, txn) -} - -// GetAllIndexes gets all the indexes in the database. -func (db *explicitTxnDB) GetAllIndexes( - ctx context.Context, -) (map[client.CollectionName][]client.IndexDescription, error) { - return db.getAllIndexDescriptions(ctx, db.txn) -} - -// AddSchema takes the provided GQL schema in SDL format, and applies it to the database, -// creating the necessary collections, request types, etc. -// -// All schema types provided must not exist prior to calling this, and they may not reference existing -// types previously defined. -func (db *implicitTxnDB) AddSchema(ctx context.Context, schemaString string) ([]client.CollectionDescription, error) { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - cols, err := db.addSchema(ctx, txn, schemaString) - if err != nil { - return nil, err - } - - if err := txn.Commit(ctx); err != nil { - return nil, err - } - return cols, nil -} - -// AddSchema takes the provided GQL schema in SDL format, and applies it to the database, -// creating the necessary collections, request types, etc. -// -// All schema types provided must not exist prior to calling this, and they may not reference existing -// types previously defined. -func (db *explicitTxnDB) AddSchema(ctx context.Context, schemaString string) ([]client.CollectionDescription, error) { - return db.addSchema(ctx, db.txn, schemaString) -} - -// PatchSchema takes the given JSON patch string and applies it to the set of CollectionDescriptions -// present in the database. -// -// It will also update the GQL types used by the query system. It will error and not apply any of the -// requested, valid updates should the net result of the patch result in an invalid state. The -// individual operations defined in the patch do not need to result in a valid state, only the net result -// of the full patch. -// -// The collections (including the schema version ID) will only be updated if any changes have actually -// been made, if the net result of the patch matches the current persisted description then no changes -// will be applied. -func (db *implicitTxnDB) PatchSchema( - ctx context.Context, - patchString string, - migration immutable.Option[model.Lens], - setAsDefaultVersion bool, -) error { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.patchSchema(ctx, txn, patchString, migration, setAsDefaultVersion) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -// PatchSchema takes the given JSON patch string and applies it to the set of CollectionDescriptions -// present in the database. -// -// It will also update the GQL types used by the query system. It will error and not apply any of the -// requested, valid updates should the net result of the patch result in an invalid state. The -// individual operations defined in the patch do not need to result in a valid state, only the net result -// of the full patch. -// -// The collections (including the schema version ID) will only be updated if any changes have actually -// been made, if the net result of the patch matches the current persisted description then no changes -// will be applied. -func (db *explicitTxnDB) PatchSchema( - ctx context.Context, - patchString string, - migration immutable.Option[model.Lens], - setAsDefaultVersion bool, -) error { - return db.patchSchema(ctx, db.txn, patchString, migration, setAsDefaultVersion) -} - -func (db *implicitTxnDB) PatchCollection( - ctx context.Context, - patchString string, -) error { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.patchCollection(ctx, txn, patchString) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -func (db *explicitTxnDB) PatchCollection( - ctx context.Context, - patchString string, -) error { - return db.patchCollection(ctx, db.txn, patchString) -} - -func (db *implicitTxnDB) SetActiveSchemaVersion(ctx context.Context, schemaVersionID string) error { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.setActiveSchemaVersion(ctx, txn, schemaVersionID) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -func (db *explicitTxnDB) SetActiveSchemaVersion(ctx context.Context, schemaVersionID string) error { - return db.setActiveSchemaVersion(ctx, db.txn, schemaVersionID) -} - -func (db *implicitTxnDB) SetMigration(ctx context.Context, cfg client.LensConfig) error { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.setMigration(ctx, txn, cfg) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -func (db *explicitTxnDB) SetMigration(ctx context.Context, cfg client.LensConfig) error { - return db.setMigration(ctx, db.txn, cfg) -} - -func (db *implicitTxnDB) AddView( - ctx context.Context, - query string, - sdl string, - transform immutable.Option[model.Lens], -) ([]client.CollectionDefinition, error) { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return nil, err - } - defer txn.Discard(ctx) - - defs, err := db.addView(ctx, txn, query, sdl, transform) - if err != nil { - return nil, err - } - - err = txn.Commit(ctx) - if err != nil { - return nil, err - } - - return defs, nil -} - -func (db *explicitTxnDB) AddView( - ctx context.Context, - query string, - sdl string, - transform immutable.Option[model.Lens], -) ([]client.CollectionDefinition, error) { - return db.addView(ctx, db.txn, query, sdl, transform) -} - -// BasicImport imports a json dataset. -// filepath must be accessible to the node. -func (db *implicitTxnDB) BasicImport(ctx context.Context, filepath string) error { - txn, err := db.NewTxn(ctx, false) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.basicImport(ctx, txn, filepath) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -// BasicImport imports a json dataset. -// filepath must be accessible to the node. -func (db *explicitTxnDB) BasicImport(ctx context.Context, filepath string) error { - return db.basicImport(ctx, db.txn, filepath) -} - -// BasicExport exports the current data or subset of data to file in json format. -func (db *implicitTxnDB) BasicExport(ctx context.Context, config *client.BackupConfig) error { - txn, err := db.NewTxn(ctx, true) - if err != nil { - return err - } - defer txn.Discard(ctx) - - err = db.basicExport(ctx, txn, config) - if err != nil { - return err - } - - return txn.Commit(ctx) -} - -// BasicExport exports the current data or subset of data to file in json format. -func (db *explicitTxnDB) BasicExport(ctx context.Context, config *client.BackupConfig) error { - return db.basicExport(ctx, db.txn, config) -} - -// LensRegistry returns the LensRegistry in use by this database instance. -// -// It exposes several useful thread-safe migration related functions. -func (db *explicitTxnDB) LensRegistry() client.LensRegistry { - return db.lensRegistry -} diff --git a/http/client.go b/http/client.go index 69c5f2a503..8837ce2e2d 100644 --- a/http/client.go +++ b/http/client.go @@ -86,11 +86,6 @@ func (c *Client) NewConcurrentTxn(ctx context.Context, readOnly bool) (datastore return &Transaction{txRes.ID, c.http}, nil } -func (c *Client) WithTxn(tx datastore.Txn) client.Store { - client := c.http.withTxn(tx.ID()) - return &Client{client} -} - func (c *Client) BasicImport(ctx context.Context, filepath string) error { methodURL := c.http.baseURL.JoinPath("backup", "import") diff --git a/http/client_collection.go b/http/client_collection.go index c53bc7e7ff..39ede6aafc 100644 --- a/http/client_collection.go +++ b/http/client_collection.go @@ -25,7 +25,6 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" - "github.com/sourcenetwork/defradb/datastore" ) var _ client.Collection = (*Collection)(nil) @@ -445,13 +444,6 @@ func (c *Collection) Get( return doc, nil } -func (c *Collection) WithTxn(tx datastore.Txn) client.Collection { - return &Collection{ - http: c.http.withTxn(tx.ID()), - def: c.def, - } -} - func (c *Collection) GetAllDocIDs( ctx context.Context, identity immutable.Option[string], diff --git a/http/client_lens.go b/http/client_lens.go index 9021aa31d6..34945a41d6 100644 --- a/http/client_lens.go +++ b/http/client_lens.go @@ -21,7 +21,6 @@ import ( "github.com/sourcenetwork/immutable/enumerable" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" ) var _ client.LensRegistry = (*LensRegistry)(nil) @@ -31,11 +30,6 @@ type LensRegistry struct { http *httpClient } -func (c *LensRegistry) WithTxn(tx datastore.Txn) client.LensRegistry { - http := c.http.withTxn(tx.ID()) - return &LensRegistry{http} -} - type setMigrationRequest struct { CollectionID uint32 Config model.Lens diff --git a/http/handler.go b/http/handler.go index 7cd278593b..e6d83dbdd3 100644 --- a/http/handler.go +++ b/http/handler.go @@ -54,7 +54,6 @@ func NewApiRouter() (*Router, error) { }) router.AddRouteGroup(func(r *Router) { - r.AddMiddleware(LensMiddleware) lens_handler.bindRoutes(r) }) @@ -82,7 +81,6 @@ func NewHandler(db client.DB) (*Handler, error) { r.Use( ApiMiddleware(db, txs), TransactionMiddleware, - StoreMiddleware, ) r.Handle("/*", router) }) diff --git a/http/handler_ccip.go b/http/handler_ccip.go index 36151c5cc3..dfe8a66083 100644 --- a/http/handler_ccip.go +++ b/http/handler_ccip.go @@ -35,7 +35,7 @@ type CCIPResponse struct { // ExecCCIP handles GraphQL over Cross Chain Interoperability Protocol requests. func (c *ccipHandler) ExecCCIP(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) var ccipReq CCIPRequest switch req.Method { diff --git a/http/handler_collection.go b/http/handler_collection.go index 1f41442849..8b7f0cf64c 100644 --- a/http/handler_collection.go +++ b/http/handler_collection.go @@ -331,7 +331,7 @@ func (s *collectionHandler) CreateIndex(rw http.ResponseWriter, req *http.Reques } func (s *collectionHandler) GetIndexes(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) indexesMap, err := store.GetAllIndexes(req.Context()) if err != nil { diff --git a/http/handler_lens.go b/http/handler_lens.go index 532eaacefc..94ef9c2abe 100644 --- a/http/handler_lens.go +++ b/http/handler_lens.go @@ -22,9 +22,9 @@ import ( type lensHandler struct{} func (s *lensHandler) ReloadLenses(rw http.ResponseWriter, req *http.Request) { - lens := req.Context().Value(lensContextKey).(client.LensRegistry) + store := req.Context().Value(dbContextKey).(client.Store) - err := lens.ReloadLenses(req.Context()) + err := store.LensRegistry().ReloadLenses(req.Context()) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -33,7 +33,7 @@ func (s *lensHandler) ReloadLenses(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { - lens := req.Context().Value(lensContextKey).(client.LensRegistry) + store := req.Context().Value(dbContextKey).(client.Store) var request setMigrationRequest if err := requestJSON(req, &request); err != nil { @@ -41,7 +41,7 @@ func (s *lensHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { return } - err := lens.SetMigration(req.Context(), request.CollectionID, request.Config) + err := store.LensRegistry().SetMigration(req.Context(), request.CollectionID, request.Config) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -50,7 +50,7 @@ func (s *lensHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) MigrateUp(rw http.ResponseWriter, req *http.Request) { - lens := req.Context().Value(lensContextKey).(client.LensRegistry) + store := req.Context().Value(dbContextKey).(client.Store) var request migrateRequest if err := requestJSON(req, &request); err != nil { @@ -58,7 +58,7 @@ func (s *lensHandler) MigrateUp(rw http.ResponseWriter, req *http.Request) { return } - result, err := lens.MigrateUp(req.Context(), enumerable.New(request.Data), request.CollectionID) + result, err := store.LensRegistry().MigrateUp(req.Context(), enumerable.New(request.Data), request.CollectionID) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return @@ -75,7 +75,7 @@ func (s *lensHandler) MigrateUp(rw http.ResponseWriter, req *http.Request) { } func (s *lensHandler) MigrateDown(rw http.ResponseWriter, req *http.Request) { - lens := req.Context().Value(lensContextKey).(client.LensRegistry) + store := req.Context().Value(dbContextKey).(client.Store) var request migrateRequest if err := requestJSON(req, &request); err != nil { @@ -83,7 +83,7 @@ func (s *lensHandler) MigrateDown(rw http.ResponseWriter, req *http.Request) { return } - result, err := lens.MigrateDown(req.Context(), enumerable.New(request.Data), request.CollectionID) + result, err := store.LensRegistry().MigrateDown(req.Context(), enumerable.New(request.Data), request.CollectionID) if err != nil { responseJSON(rw, http.StatusBadRequest, errorResponse{err}) return diff --git a/http/handler_store.go b/http/handler_store.go index 4c57eda34f..c71e108818 100644 --- a/http/handler_store.go +++ b/http/handler_store.go @@ -27,7 +27,7 @@ import ( type storeHandler struct{} func (s *storeHandler) BasicImport(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) var config client.BackupConfig if err := requestJSON(req, &config); err != nil { @@ -43,7 +43,7 @@ func (s *storeHandler) BasicImport(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) BasicExport(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) var config client.BackupConfig if err := requestJSON(req, &config); err != nil { @@ -59,7 +59,7 @@ func (s *storeHandler) BasicExport(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) AddSchema(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) schema, err := io.ReadAll(req.Body) if err != nil { @@ -75,7 +75,7 @@ func (s *storeHandler) AddSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) PatchSchema(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) var message patchSchemaRequest err := requestJSON(req, &message) @@ -93,7 +93,7 @@ func (s *storeHandler) PatchSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) PatchCollection(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) var patch string err := requestJSON(req, &patch) @@ -111,7 +111,7 @@ func (s *storeHandler) PatchCollection(rw http.ResponseWriter, req *http.Request } func (s *storeHandler) SetActiveSchemaVersion(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) schemaVersionID, err := io.ReadAll(req.Body) if err != nil { @@ -127,7 +127,7 @@ func (s *storeHandler) SetActiveSchemaVersion(rw http.ResponseWriter, req *http. } func (s *storeHandler) AddView(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) var message addViewRequest err := requestJSON(req, &message) @@ -146,7 +146,7 @@ func (s *storeHandler) AddView(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) var cfg client.LensConfig if err := requestJSON(req, &cfg); err != nil { @@ -163,7 +163,7 @@ func (s *storeHandler) SetMigration(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) GetCollection(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) options := client.CollectionFetchOptions{} if req.URL.Query().Has("name") { @@ -199,7 +199,7 @@ func (s *storeHandler) GetCollection(rw http.ResponseWriter, req *http.Request) } func (s *storeHandler) GetSchema(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) options := client.SchemaFetchOptions{} if req.URL.Query().Has("version_id") { @@ -221,7 +221,7 @@ func (s *storeHandler) GetSchema(rw http.ResponseWriter, req *http.Request) { } func (s *storeHandler) GetAllIndexes(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) indexes, err := store.GetAllIndexes(req.Context()) if err != nil { @@ -296,7 +296,7 @@ func (res *GraphQLResponse) UnmarshalJSON(data []byte) error { } func (s *storeHandler) ExecRequest(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + store := req.Context().Value(dbContextKey).(client.Store) var request GraphQLRequest switch { diff --git a/http/http_client.go b/http/http_client.go index 13abb3c6d0..5bcda30dcd 100644 --- a/http/http_client.go +++ b/http/http_client.go @@ -17,12 +17,13 @@ import ( "net/http" "net/url" "strings" + + "github.com/sourcenetwork/defradb/db" ) type httpClient struct { client *http.Client baseURL *url.URL - txValue string } func newHttpClient(rawURL string) (*httpClient, error) { @@ -40,20 +41,13 @@ func newHttpClient(rawURL string) (*httpClient, error) { return &client, nil } -func (c *httpClient) withTxn(value uint64) *httpClient { - return &httpClient{ - client: c.client, - baseURL: c.baseURL, - txValue: fmt.Sprintf("%d", value), - } -} - func (c *httpClient) setDefaultHeaders(req *http.Request) { req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/json") - if c.txValue != "" { - req.Header.Set(TX_HEADER_NAME, c.txValue) + txn, ok := db.TryGetContextTxn(req.Context()) + if ok { + req.Header.Set(TX_HEADER_NAME, fmt.Sprintf("%d", txn.ID())) } } diff --git a/http/middleware.go b/http/middleware.go index f18ba8bf60..674921fd73 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -23,6 +23,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/db" ) const TX_HEADER_NAME = "x-defradb-tx" @@ -34,20 +35,6 @@ var ( txsContextKey = contextKey("txs") // dbContextKey is the context key for the client.DB dbContextKey = contextKey("db") - // txContextKey is the context key for the datastore.Txn - // - // This will only be set if a transaction id is specified. - txContextKey = contextKey("tx") - // storeContextKey is the context key for the client.Store - // - // If a transaction exists, all operations will be executed - // in the current transaction context. - storeContextKey = contextKey("store") - // lensContextKey is the context key for the client.LensRegistry - // - // If a transaction exists, all operations will be executed - // in the current transaction context. - lensContextKey = contextKey("lens") // colContextKey is the context key for the client.Collection // // If a transaction exists, all operations will be executed @@ -102,42 +89,10 @@ func TransactionMiddleware(next http.Handler) http.Handler { next.ServeHTTP(rw, req) return } - - ctx := context.WithValue(req.Context(), txContextKey, tx) - next.ServeHTTP(rw, req.WithContext(ctx)) - }) -} - -// StoreMiddleware sets the db context for the current request. -func StoreMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - db := req.Context().Value(dbContextKey).(client.DB) - - var store client.Store - if tx, ok := req.Context().Value(txContextKey).(datastore.Txn); ok { - store = db.WithTxn(tx) - } else { - store = db + ctx := req.Context() + if val, ok := tx.(datastore.Txn); ok { + ctx = db.SetContextTxn(ctx, val) } - - ctx := context.WithValue(req.Context(), storeContextKey, store) - next.ServeHTTP(rw, req.WithContext(ctx)) - }) -} - -// LensMiddleware sets the lens context for the current request. -func LensMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) - - var lens client.LensRegistry - if tx, ok := req.Context().Value(txContextKey).(datastore.Txn); ok { - lens = store.LensRegistry().WithTxn(tx) - } else { - lens = store.LensRegistry() - } - - ctx := context.WithValue(req.Context(), lensContextKey, lens) next.ServeHTTP(rw, req.WithContext(ctx)) }) } @@ -145,18 +100,14 @@ func LensMiddleware(next http.Handler) http.Handler { // CollectionMiddleware sets the collection context for the current request. func CollectionMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - store := req.Context().Value(storeContextKey).(client.Store) + db := req.Context().Value(dbContextKey).(client.DB) - col, err := store.GetCollectionByName(req.Context(), chi.URLParam(req, "name")) + col, err := db.GetCollectionByName(req.Context(), chi.URLParam(req, "name")) if err != nil { rw.WriteHeader(http.StatusNotFound) return } - if tx, ok := req.Context().Value(txContextKey).(datastore.Txn); ok { - col = col.WithTxn(tx) - } - ctx := context.WithValue(req.Context(), colContextKey, col) next.ServeHTTP(rw, req.WithContext(ctx)) }) diff --git a/net/peer_collection.go b/net/peer_collection.go index 4ef1139a1c..d8d27b361d 100644 --- a/net/peer_collection.go +++ b/net/peer_collection.go @@ -19,6 +19,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" + "github.com/sourcenetwork/defradb/db" ) const marker = byte(0xff) @@ -33,8 +34,9 @@ func (p *Peer) AddP2PCollections(ctx context.Context, collectionIDs []string) er // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - storeCol, err := p.db.WithTxn(txn).GetCollections( - p.ctx, + ctx = db.SetContextTxn(ctx, txn) + storeCol, err := p.db.GetCollections( + ctx, client.CollectionFetchOptions{ SchemaRoot: immutable.Some(col), }, @@ -112,8 +114,9 @@ func (p *Peer) RemoveP2PCollections(ctx context.Context, collectionIDs []string) // first let's make sure the collections actually exists storeCollections := []client.Collection{} for _, col := range collectionIDs { - storeCol, err := p.db.WithTxn(txn).GetCollections( - p.ctx, + ctx = db.SetContextTxn(ctx, txn) + storeCol, err := p.db.GetCollections( + ctx, client.CollectionFetchOptions{ SchemaRoot: immutable.Some(col), }, diff --git a/net/peer_replicator.go b/net/peer_replicator.go index 93f6070f0b..1dd3c47cf4 100644 --- a/net/peer_replicator.go +++ b/net/peer_replicator.go @@ -21,6 +21,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" + "github.com/sourcenetwork/defradb/db" ) func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { @@ -40,12 +41,15 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { return err } + // set transaction for all operations + ctx = db.SetContextTxn(ctx, txn) + var collections []client.Collection switch { case len(rep.Schemas) > 0: // if specific collections are chosen get them by name for _, name := range rep.Schemas { - col, err := p.db.WithTxn(txn).GetCollectionByName(ctx, name) + col, err := p.db.GetCollectionByName(ctx, name) if err != nil { return NewErrReplicatorCollections(err) } @@ -60,7 +64,7 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { default: // default to all collections (unless a collection contains a policy). // TODO-ACP: default to all collections after resolving https://github.com/sourcenetwork/defradb/issues/2366 - allCollections, err := p.db.WithTxn(txn).GetCollections(ctx, client.CollectionFetchOptions{}) + allCollections, err := p.db.GetCollections(ctx, client.CollectionFetchOptions{}) if err != nil { return NewErrReplicatorCollections(err) } @@ -109,7 +113,7 @@ func (p *Peer) SetReplicator(ctx context.Context, rep client.Replicator) error { // push all collection documents to the replicator peer for _, col := range added { // TODO-ACP: Support ACP <> P2P - https://github.com/sourcenetwork/defradb/issues/2366 - keysCh, err := col.WithTxn(txn).GetAllDocIDs(ctx, acpIdentity.NoIdentity) + keysCh, err := col.GetAllDocIDs(ctx, acpIdentity.NoIdentity) if err != nil { return NewErrReplicatorDocID(err, col.Name().Value(), rep.Info.ID) } @@ -136,12 +140,15 @@ func (p *Peer) DeleteReplicator(ctx context.Context, rep client.Replicator) erro return err } + // set transaction for all operations + ctx = db.SetContextTxn(ctx, txn) + var collections []client.Collection switch { case len(rep.Schemas) > 0: // if specific collections are chosen get them by name for _, name := range rep.Schemas { - col, err := p.db.WithTxn(txn).GetCollectionByName(ctx, name) + col, err := p.db.GetCollectionByName(ctx, name) if err != nil { return NewErrReplicatorCollections(err) } @@ -156,7 +163,7 @@ func (p *Peer) DeleteReplicator(ctx context.Context, rep client.Replicator) erro default: // default to all collections - collections, err = p.db.WithTxn(txn).GetCollections(ctx, client.CollectionFetchOptions{}) + collections, err = p.db.GetCollections(ctx, client.CollectionFetchOptions{}) if err != nil { return NewErrReplicatorCollections(err) } diff --git a/net/server.go b/net/server.go index 58a9f16f75..73496559cf 100644 --- a/net/server.go +++ b/net/server.go @@ -33,6 +33,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/core" "github.com/sourcenetwork/defradb/datastore/badger/v4" + "github.com/sourcenetwork/defradb/db" "github.com/sourcenetwork/defradb/errors" pb "github.com/sourcenetwork/defradb/net/pb" ) @@ -250,11 +251,13 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL return nil, err } defer txn.Discard(ctx) - store := s.db.WithTxn(txn) + + // use a transaction for all operations + ctx = db.SetContextTxn(ctx, txn) // Currently a schema is the best way we have to link a push log request to a collection, // this will change with https://github.com/sourcenetwork/defradb/issues/1085 - col, err := s.getActiveCollection(ctx, store, string(req.Body.SchemaRoot)) + col, err := s.getActiveCollection(ctx, s.db, string(req.Body.SchemaRoot)) if err != nil { return nil, err } @@ -271,9 +274,9 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL return nil, errors.Wrap("failed to decode block to ipld.Node", err) } - var session sync.WaitGroup + var wg sync.WaitGroup bp := newBlockProcessor(s.peer, txn, col, dsKey, getter) - err = bp.processRemoteBlock(ctx, &session, nd, true) + err = bp.processRemoteBlock(ctx, &wg, nd, true) if err != nil { log.ErrorContextE( ctx, @@ -283,10 +286,10 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL corelog.Any("CID", cid), ) } - session.Wait() + wg.Wait() bp.mergeBlocks(ctx) - err = s.syncIndexedDocs(ctx, col.WithTxn(txn), docID) + err = s.syncIndexedDocs(ctx, col, docID) if err != nil { return nil, err } @@ -350,14 +353,12 @@ func (s *server) syncIndexedDocs( col client.Collection, docID client.DocID, ) error { - preTxnCol, err := s.db.GetCollectionByName(ctx, col.Name().Value()) - if err != nil { - return err - } + // remove transaction from old context + oldCtx := db.SetContextTxn(ctx, nil) //TODO-ACP: https://github.com/sourcenetwork/defradb/issues/2365 // Resolve while handling acp <> secondary indexes. - oldDoc, err := preTxnCol.Get(ctx, acpIdentity.NoIdentity, docID, false) + oldDoc, err := col.Get(oldCtx, acpIdentity.NoIdentity, docID, false) isNewDoc := errors.Is(err, client.ErrDocumentNotFoundOrNotAuthorized) if !isNewDoc && err != nil { return err @@ -372,7 +373,7 @@ func (s *server) syncIndexedDocs( } if isDeletedDoc { - return preTxnCol.DeleteDocIndex(ctx, oldDoc) + return col.DeleteDocIndex(oldCtx, oldDoc) } else if isNewDoc { return col.CreateDocIndex(ctx, doc) } else { diff --git a/planner/create.go b/planner/create.go index 3333ae999e..bedb1be5d5 100644 --- a/planner/create.go +++ b/planner/create.go @@ -78,7 +78,7 @@ func (n *createNode) Next() (bool, error) { return false, nil } - if err := n.collection.WithTxn(n.p.txn).Create( + if err := n.collection.Create( n.p.ctx, n.p.identity, n.doc, diff --git a/planner/delete.go b/planner/delete.go index 74bb14d202..87cf0994ac 100644 --- a/planner/delete.go +++ b/planner/delete.go @@ -140,7 +140,7 @@ func (p *Planner) DeleteDocs(parsed *mapper.Mutation) (planNode, error) { p: p, filter: parsed.Filter, docIDs: parsed.DocIDs.Value(), - collection: col.WithTxn(p.txn), + collection: col, source: slctNode, docMapper: docMapper{parsed.DocumentMapping}, }, nil diff --git a/planner/update.go b/planner/update.go index b86c616dbb..458094d4e0 100644 --- a/planner/update.go +++ b/planner/update.go @@ -169,7 +169,7 @@ func (p *Planner) UpdateDocs(parsed *mapper.Mutation) (planNode, error) { if err != nil { return nil, err } - update.collection = col.WithTxn(p.txn) + update.collection = col // create the results Select node resultsNode, err := p.Select(&parsed.Select) diff --git a/tests/bench/query/planner/utils.go b/tests/bench/query/planner/utils.go index 5bb4472840..caba91836d 100644 --- a/tests/bench/query/planner/utils.go +++ b/tests/bench/query/planner/utils.go @@ -57,11 +57,11 @@ func runMakePlanBench( fixture fixtures.Generator, query string, ) error { - db, _, err := benchutils.SetupDBAndCollections(b, ctx, fixture) + d, _, err := benchutils.SetupDBAndCollections(b, ctx, fixture) if err != nil { return err } - defer db.Close() + defer d.Close() parser, err := buildParser(ctx, fixture) if err != nil { @@ -73,18 +73,18 @@ func runMakePlanBench( if len(errs) > 0 { return errors.Wrap("failed to parse query string", errors.New(fmt.Sprintf("%v", errs))) } - txn, err := db.NewTxn(ctx, false) + txn, err := d.NewTxn(ctx, false) if err != nil { return errors.Wrap("failed to create txn", err) } - b.ResetTimer() + for i := 0; i < b.N; i++ { planner := planner.New( ctx, acpIdentity.NoIdentity, acp.NoACP, - db.WithTxn(txn), + d, txn, ) plan, err := planner.MakePlan(q) diff --git a/tests/clients/cli/wrapper.go b/tests/clients/cli/wrapper.go index d10188d4b2..2ddaf86137 100644 --- a/tests/clients/cli/wrapper.go +++ b/tests/clients/cli/wrapper.go @@ -406,7 +406,7 @@ func (w *Wrapper) ExecRequest( result := &client.RequestResult{} - stdOut, stdErr, err := w.cmd.executeStream(args) + stdOut, stdErr, err := w.cmd.executeStream(ctx, args) if err != nil { result.GQL.Errors = []error{err} return result @@ -515,13 +515,6 @@ func (w *Wrapper) NewConcurrentTxn(ctx context.Context, readOnly bool) (datastor return &Transaction{tx, w.cmd}, nil } -func (w *Wrapper) WithTxn(tx datastore.Txn) client.Store { - return &Wrapper{ - node: w.node, - cmd: w.cmd.withTxn(tx), - } -} - func (w *Wrapper) Root() datastore.RootStore { return w.node.Root() } diff --git a/tests/clients/cli/wrapper_cli.go b/tests/clients/cli/wrapper_cli.go index 2a985dcb18..9076605857 100644 --- a/tests/clients/cli/wrapper_cli.go +++ b/tests/clients/cli/wrapper_cli.go @@ -17,12 +17,11 @@ import ( "strings" "github.com/sourcenetwork/defradb/cli" - "github.com/sourcenetwork/defradb/datastore" + "github.com/sourcenetwork/defradb/db" ) type cliWrapper struct { address string - txValue string } func newCliWrapper(address string) *cliWrapper { @@ -31,15 +30,8 @@ func newCliWrapper(address string) *cliWrapper { } } -func (w *cliWrapper) withTxn(tx datastore.Txn) *cliWrapper { - return &cliWrapper{ - address: w.address, - txValue: fmt.Sprintf("%d", tx.ID()), - } -} - -func (w *cliWrapper) execute(_ context.Context, args []string) ([]byte, error) { - stdOut, stdErr, err := w.executeStream(args) +func (w *cliWrapper) execute(ctx context.Context, args []string) ([]byte, error) { + stdOut, stdErr, err := w.executeStream(ctx, args) if err != nil { return nil, err } @@ -57,12 +49,13 @@ func (w *cliWrapper) execute(_ context.Context, args []string) ([]byte, error) { return stdOutData, nil } -func (w *cliWrapper) executeStream(args []string) (io.ReadCloser, io.ReadCloser, error) { +func (w *cliWrapper) executeStream(ctx context.Context, args []string) (io.ReadCloser, io.ReadCloser, error) { stdOutRead, stdOutWrite := io.Pipe() stdErrRead, stdErrWrite := io.Pipe() - if w.txValue != "" { - args = append(args, "--tx", w.txValue) + tx, ok := db.TryGetContextTxn(ctx) + if ok { + args = append(args, "--tx", fmt.Sprintf("%d", tx.ID())) } args = append(args, "--url", w.address) diff --git a/tests/clients/cli/wrapper_collection.go b/tests/clients/cli/wrapper_collection.go index 9bb8fb9938..861606a2d1 100644 --- a/tests/clients/cli/wrapper_collection.go +++ b/tests/clients/cli/wrapper_collection.go @@ -20,7 +20,6 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/client/request" - "github.com/sourcenetwork/defradb/datastore" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/http" ) @@ -448,13 +447,6 @@ func (c *Collection) Get( return doc, nil } -func (c *Collection) WithTxn(tx datastore.Txn) client.Collection { - return &Collection{ - cmd: c.cmd.withTxn(tx), - def: c.def, - } -} - func (c *Collection) GetAllDocIDs( ctx context.Context, identity immutable.Option[string], @@ -466,7 +458,7 @@ func (c *Collection) GetAllDocIDs( args := []string{"client", "collection", "docIDs"} args = append(args, "--name", c.Description().Name.Value()) - stdOut, _, err := c.cmd.executeStream(args) + stdOut, _, err := c.cmd.executeStream(ctx, args) if err != nil { return nil, err } diff --git a/tests/clients/cli/wrapper_lens.go b/tests/clients/cli/wrapper_lens.go index da6011b9eb..a9f3e20bd1 100644 --- a/tests/clients/cli/wrapper_lens.go +++ b/tests/clients/cli/wrapper_lens.go @@ -20,7 +20,6 @@ import ( "github.com/sourcenetwork/immutable/enumerable" "github.com/sourcenetwork/defradb/client" - "github.com/sourcenetwork/defradb/datastore" ) var _ client.LensRegistry = (*LensRegistry)(nil) @@ -29,10 +28,6 @@ type LensRegistry struct { cmd *cliWrapper } -func (w *LensRegistry) WithTxn(tx datastore.Txn) client.LensRegistry { - return &LensRegistry{w.cmd.withTxn(tx)} -} - func (w *LensRegistry) SetMigration(ctx context.Context, collectionID uint32, config model.Lens) error { args := []string{"client", "schema", "migration", "set-registry"} diff --git a/tests/clients/http/wrapper.go b/tests/clients/http/wrapper.go index 415212b99c..51fe7ae66b 100644 --- a/tests/clients/http/wrapper.go +++ b/tests/clients/http/wrapper.go @@ -201,10 +201,6 @@ func (w *Wrapper) NewConcurrentTxn(ctx context.Context, readOnly bool) (datastor return &TxWrapper{server, client}, nil } -func (w *Wrapper) WithTxn(tx datastore.Txn) client.Store { - return w.client.WithTxn(tx) -} - func (w *Wrapper) Root() datastore.RootStore { return w.node.Root() } diff --git a/tests/integration/events/simple/with_create_txn_test.go b/tests/integration/events/simple/with_create_txn_test.go index 7ff1f838e7..f90fc96a88 100644 --- a/tests/integration/events/simple/with_create_txn_test.go +++ b/tests/integration/events/simple/with_create_txn_test.go @@ -19,6 +19,7 @@ import ( acpIdentity "github.com/sourcenetwork/defradb/acp/identity" "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/db" testUtils "github.com/sourcenetwork/defradb/tests/integration/events" ) @@ -42,7 +43,9 @@ func TestEventsSimpleWithCreateWithTxnDiscarded(t *testing.T) { func(ctx context.Context, d client.DB) { txn, err := d.NewTxn(ctx, false) assert.Nil(t, err) - r := d.WithTxn(txn).ExecRequest( + + ctx = db.SetContextTxn(ctx, txn) + r := d.ExecRequest( ctx, acpIdentity.NoIdentity, `mutation { diff --git a/tests/integration/lens.go b/tests/integration/lens.go index 69c49a1cbc..541b708a33 100644 --- a/tests/integration/lens.go +++ b/tests/integration/lens.go @@ -14,6 +14,7 @@ import ( "github.com/sourcenetwork/immutable" "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/db" ) // ConfigureMigration is a test action which will configure a Lens migration using the @@ -42,9 +43,10 @@ func configureMigration( action ConfigureMigration, ) { for _, node := range getNodes(action.NodeID, s.nodes) { - db := getStore(s, node, action.TransactionID, action.ExpectedError) + txn := getTransaction(s, node, action.TransactionID, action.ExpectedError) + ctx := db.SetContextTxn(s.ctx, txn) - err := db.SetMigration(s.ctx, action.LensConfig) + err := node.SetMigration(ctx, action.LensConfig) expectedErrorRaised := AssertError(s.t, s.testCase.Description, err, action.ExpectedError) assertExpectedErrorRaised(s.t, s.testCase.Description, action.ExpectedError, expectedErrorRaised) diff --git a/tests/integration/utils2.go b/tests/integration/utils2.go index 18c97e76d1..deb38acde3 100644 --- a/tests/integration/utils2.go +++ b/tests/integration/utils2.go @@ -32,6 +32,7 @@ import ( "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/datastore" badgerds "github.com/sourcenetwork/defradb/datastore/badger/v4" + "github.com/sourcenetwork/defradb/db" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/net" "github.com/sourcenetwork/defradb/request/graphql" @@ -1080,8 +1081,9 @@ func getCollections( action GetCollections, ) { for _, node := range getNodes(action.NodeID, s.nodes) { - db := getStore(s, node, action.TransactionID, "") - results, err := db.GetCollections(s.ctx, action.FilterOptions) + txn := getTransaction(s, node, action.TransactionID, "") + ctx := db.SetContextTxn(s.ctx, txn) + results, err := node.GetCollections(ctx, action.FilterOptions) expectedErrorRaised := AssertError(s.t, s.testCase.Description, err, action.ExpectedError) assertExpectedErrorRaised(s.t, s.testCase.Description, action.ExpectedError, expectedErrorRaised) @@ -1249,11 +1251,12 @@ func createDocViaGQL( input, ) - db := getStore(s, node, immutable.None[int](), action.ExpectedError) + txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) identity := acpIdentity.NewIdentity(action.Identity) - result := db.ExecRequest( - s.ctx, + ctx := db.SetContextTxn(s.ctx, txn) + result := node.ExecRequest( + ctx, identity, request, ) @@ -1426,10 +1429,10 @@ func updateDocViaGQL( input, ) - db := getStore(s, node, immutable.None[int](), action.ExpectedError) - - result := db.ExecRequest( - s.ctx, + txn := getTransaction(s, node, immutable.None[int](), action.ExpectedError) + ctx := db.SetContextTxn(s.ctx, txn) + result := node.ExecRequest( + ctx, acpIdentity.NewIdentity(action.Identity), request, ) @@ -1591,14 +1594,14 @@ func withRetry( return nil } -func getStore( +func getTransaction( s *state, db client.DB, transactionSpecifier immutable.Option[int], expectedError string, -) client.Store { +) datastore.Txn { if !transactionSpecifier.HasValue() { - return db + return nil } transactionID := transactionSpecifier.Value() @@ -1619,7 +1622,7 @@ func getStore( s.txns[transactionID] = txn } - return db.WithTxn(s.txns[transactionID]) + return s.txns[transactionID] } // commitTransaction commits the given transaction. @@ -1647,9 +1650,10 @@ func executeRequest( ) { var expectedErrorRaised bool for nodeID, node := range getNodes(action.NodeID, s.nodes) { - db := getStore(s, node, action.TransactionID, action.ExpectedError) - result := db.ExecRequest( - s.ctx, + txn := getTransaction(s, node, action.TransactionID, action.ExpectedError) + ctx := db.SetContextTxn(s.ctx, txn) + result := node.ExecRequest( + ctx, acpIdentity.NewIdentity(action.Identity), action.Request, )