From 73a2d09655d136880670bd41f1a254bffe77e975 Mon Sep 17 00:00:00 2001 From: Alexandros Filios Date: Fri, 20 Sep 2024 16:23:57 +0200 Subject: [PATCH] Multi updates Signed-off-by: Alexandros Filios --- .../common/core/generic/vault/inspector.go | 2 +- .../core/generic/vault/txidstore/cache.go | 12 ++ platform/common/core/generic/vault/vault.go | 145 ++++++++++-------- platform/common/driver/vault.go | 6 + platform/fabric/driver/iterator.go | 1 + .../view/services/db/driver/badger/badger.go | 4 +- platform/view/services/db/driver/driver.go | 4 +- .../db/driver/notifier/persistence.go | 4 +- .../services/db/driver/sql/common/base.go | 20 ++- .../db/driver/sql/common/versioned.go | 124 +++++++-------- .../services/db/driver/sql/postgres/base.go | 33 ++-- .../db/driver/sql/postgres/versioned.go | 4 +- 12 files changed, 210 insertions(+), 149 deletions(-) diff --git a/platform/common/core/generic/vault/inspector.go b/platform/common/core/generic/vault/inspector.go index b14d27cb9..949607734 100644 --- a/platform/common/core/generic/vault/inspector.go +++ b/platform/common/core/generic/vault/inspector.go @@ -65,7 +65,7 @@ func (i *Inspector) SetStateMetadata(driver.Namespace, driver.PKey, driver.Metad panic("programming error: the rwset inspector is read-only") } -func (i *Inspector) SetStateMetadatas(ns driver.Namespace, kvs map[driver.PKey]driver.Metadata, block driver.BlockNum, txnum driver.TxNum) map[driver.PKey]error { +func (i *Inspector) SetStateMetadatas(ns driver.Namespace, kvs map[driver.PKey]driver.VersionedMetadataValue) map[driver.PKey]error { panic("programming error: the rwset inspector is read-only") } diff --git a/platform/common/core/generic/vault/txidstore/cache.go b/platform/common/core/generic/vault/txidstore/cache.go index 82e0b8b4d..672eee153 100644 --- a/platform/common/core/generic/vault/txidstore/cache.go +++ b/platform/common/core/generic/vault/txidstore/cache.go @@ -31,6 +31,7 @@ type cache[V driver.ValidationCode] interface { type txidStore[V driver.ValidationCode] interface { Get(txID driver.TxID) (V, string, error) Set(txID driver.TxID, code V, message string) error + SetMultiple(txs []driver.ByNum[V]) error Iterator(pos interface{}) (collections.Iterator[*driver.ByNum[V]], error) } @@ -88,6 +89,17 @@ func (s *CachedStore[V]) Set(txID string, code V, message string) error { return nil } +func (s *CachedStore[V]) SetMultiple(txs []driver.ByNum[V]) error { + s.logger.Debugf("Set values for %d txs into backed and cache", len(txs)) + if err := s.backed.SetMultiple(txs); err != nil { + return err + } + for _, tx := range txs { + s.cache.Add(tx.TxID, &Entry[V]{ValidationCode: tx.Code, ValidationMessage: tx.Message}) + } + return nil +} + func (s *CachedStore[V]) Iterator(pos interface{}) (collections.Iterator[*driver.ByNum[V]], error) { return s.backed.Iterator(pos) } diff --git a/platform/common/core/generic/vault/vault.go b/platform/common/core/generic/vault/vault.go index ffe85c149..62e6bcff4 100644 --- a/platform/common/core/generic/vault/vault.go +++ b/platform/common/core/generic/vault/vault.go @@ -37,6 +37,7 @@ type TXIDStoreReader[V driver.ValidationCode] interface { type TXIDStore[V driver.ValidationCode] interface { TXIDStoreReader[V] Set(txID driver.TxID, code V, message string) error + SetMultiple(txs []driver.ByNum[V]) error Invalidate(txID driver.TxID) } @@ -289,94 +290,110 @@ func (db *Vault[V]) commitRWs(inputs ...commitInput) error { return errors.Wrapf(err, "begin update in store for txid %v failed", inputs) } - for _, input := range inputs { - span := trace.SpanFromContext(input.ctx) + if _, err := db.setStatuses(inputs, db.vcProvider.Busy()); err != nil { + return err + } - span.AddEvent("set_tx_busy") - if err := db.txIDStore.Set(input.txID, db.vcProvider.Busy(), ""); err != nil { - if !errors.HasCause(err, UniqueKeyViolation) { - return err + db.logger.Debugf("parse writes") + writes := make(map[driver.Namespace]map[driver.PKey]VersionedValue) + for _, input := range inputs { + for ns, ws := range input.rws.Writes { + vals := versionedValues(ws, input.block, input.indexInBloc) + if nsWrites, ok := writes[ns]; !ok { + writes[ns] = vals + } else { + collections.CopyMap(nsWrites, vals) } } + } - db.logger.Debugf("parse writes [%s]", input.txID) - span.AddEvent("store_writes") - if discarded, err := db.storeWrites(input.ctx, input.rws.Writes, input.block, input.indexInBloc); err != nil { - return errors.Wrapf(err, "failed storing writes") - } else if discarded { - db.logger.Infof("Discarded changes while storing writes as duplicates. Skipping...") + if errs := db.storeAllWrites(writes); len(errs) == 0 { + db.logger.Debugf("Successfully stored writes for %d namespaces", len(writes)) + } else if discarded, err := db.discard("", 0, 0, errs); err != nil { + return errors.Wrapf(err, "failed storing writes") + } else if discarded { + db.logger.Infof("Discarded changes while storing writes as duplicates. Skipping...") + for _, input := range inputs { db.txIDStore.Invalidate(input.txID) - return nil } + return nil + } - db.logger.Debugf("parse meta writes [%s]", input.txID) - span.AddEvent("store_meta_writes") - if discarded, err := db.storeMetaWrites(input.ctx, input.rws.MetaWrites, input.block, input.indexInBloc); err != nil { - return errors.Wrapf(err, "failed storing meta writes") - } else if discarded { - db.logger.Infof("Discarded changes while storing meta writes as duplicates. Skipping...") + db.logger.Debugf("parse meta writes") + metaWrites := make(map[driver.Namespace]map[driver.PKey]driver.VersionedMetadataValue) + for _, input := range inputs { + for ns, ws := range input.rws.MetaWrites { + vals := versionedMetaValues(ws, input.block, input.indexInBloc) + if nsWrites, ok := metaWrites[ns]; !ok { + metaWrites[ns] = vals + } else { + collections.CopyMap(nsWrites, vals) + } + } + } + if errs := db.storeAllMetaWrites(metaWrites); len(errs) == 0 { + db.logger.Debugf("Successfully stored meta writes for %d namespaces", len(metaWrites)) + } else if discarded, err := db.discard("", 0, 0, errs); err != nil { + return errors.Wrapf(err, "failed storing meta writes") + } else if discarded { + db.logger.Infof("Discarded changes while storing meta writes as duplicates. Skipping...") + for _, input := range inputs { db.txIDStore.Invalidate(input.txID) - return nil } + return nil + } - db.logger.Debugf("set state to valid [%s]", input.txID) - span.AddEvent("set_tx_valid") - if discarded, err := db.setTxValid(input.txID); err != nil { - return errors.Wrapf(err, "failed setting tx state to valid") - } else if discarded { - db.logger.Infof("Discarded changes while setting tx state to valid as duplicates. Skipping...") - return nil + if discarded, err := db.setStatuses(inputs, db.vcProvider.Valid()); err != nil { + if err1 := db.store.Discard(); err1 != nil { + db.logger.Errorf("got error %s; discarding caused %s", err.Error(), err1.Error()) } - + for _, input := range inputs { + db.txIDStore.Invalidate(input.txID) + } + return errors.Wrapf(err, "failed setting tx state to valid") + } else if discarded { + if err1 := db.store.Discard(); err1 != nil { + db.logger.Errorf("got unique key violation; discarding caused %s", err1.Error()) + } + db.logger.Infof("Discarded changes while setting tx state to valid as duplicates. Skipping...") + return nil } - for _, input := range inputs { - trace.SpanFromContext(input.ctx).AddEvent("commit_update") - } if err := db.store.Commit(); err != nil { return errors.Wrapf(err, "committing tx for txid in store [%v] failed", inputs) } - return nil } -func (db *Vault[V]) setTxValid(txID driver.TxID) (bool, error) { - err := db.txIDStore.Set(txID, db.vcProvider.Valid(), "") - if err == nil { - return false, nil +func (db *Vault[V]) setStatuses(inputs []commitInput, v V) (bool, error) { + txs := make([]driver.ByNum[V], len(inputs)) + for i, input := range inputs { + txs[i] = driver.ByNum[V]{TxID: input.txID, Code: v} } - if err1 := db.store.Discard(); err1 != nil { - db.logger.Errorf("got error %s; discarding caused %s", err.Error(), err1.Error()) - } - - if !errors.HasCause(err, UniqueKeyViolation) { - db.txIDStore.Invalidate(txID) - return true, errors.Wrapf(err, "error setting tx valid") + if err := db.txIDStore.SetMultiple(txs); err == nil { + return false, nil + } else if !errors.HasCause(err, UniqueKeyViolation) { + return true, err + } else { + return true, nil } - return true, nil } -func (db *Vault[V]) storeMetaWrites(ctx context.Context, writes NamespaceKeyedMetaWrites, block driver.BlockNum, indexInBloc driver.TxNum) (bool, error) { - span := trace.SpanFromContext(ctx) - for ns, keyMap := range writes { - span.AddEvent("set_tx_metadata_state") - if errs := db.store.SetStateMetadatas(ns, keyMap, block, indexInBloc); len(errs) > 0 { - return db.discard(ns, block, indexInBloc, errs) - } +func (db *Vault[V]) storeAllWrites(writes map[driver.Namespace]map[driver.PKey]VersionedValue) map[driver.PKey]error { + errs := make(map[driver.PKey]error) + for ns, vals := range writes { + collections.CopyMap(errs, db.store.SetStates(ns, vals)) } - return false, nil + return errs } -func (db *Vault[V]) storeWrites(ctx context.Context, writes Writes, block driver.BlockNum, indexInBloc driver.TxNum) (bool, error) { - span := trace.SpanFromContext(ctx) - for ns, keyMap := range writes { - span.AddEvent("set_tx_states") - if errs := db.store.SetStates(ns, versionedValues(keyMap, block, indexInBloc)); len(errs) > 0 { - return db.discard(ns, block, indexInBloc, errs) - } +func (db *Vault[V]) storeAllMetaWrites(metaWrites map[driver.Namespace]map[driver.PKey]driver.VersionedMetadataValue) map[driver.PKey]error { + errs := make(map[driver.PKey]error) + for ns, vals := range metaWrites { + collections.CopyMap(errs, db.store.SetStateMetadatas(ns, vals)) } - return false, nil + return errs } func versionedValues(keyMap NamespaceWrites, block driver.BlockNum, indexInBloc driver.TxNum) map[driver.PKey]VersionedValue { @@ -387,6 +404,14 @@ func versionedValues(keyMap NamespaceWrites, block driver.BlockNum, indexInBloc return vals } +func versionedMetaValues(keyMap KeyedMetaWrites, block driver.BlockNum, indexInBloc driver.TxNum) map[driver.PKey]driver.VersionedMetadataValue { + vals := make(map[driver.PKey]driver.VersionedMetadataValue, len(keyMap)) + for pkey, val := range keyMap { + vals[pkey] = driver.VersionedMetadataValue{Metadata: val, Block: block, TxNum: indexInBloc} + } + return vals +} + func (db *Vault[V]) discard(ns driver.Namespace, block driver.BlockNum, indexInBloc driver.TxNum, errs map[driver.PKey]error) (bool, error) { if err1 := db.store.Discard(); err1 != nil { db.logger.Errorf("got error %v; discarding caused %s", errors2.Join(collections.Values(errs)...), err1.Error()) diff --git a/platform/common/driver/vault.go b/platform/common/driver/vault.go index 7a1428694..f83775a60 100644 --- a/platform/common/driver/vault.go +++ b/platform/common/driver/vault.go @@ -26,6 +26,12 @@ type VersionedRead struct { TxNum TxNum } +type VersionedMetadataValue struct { + Block BlockNum + TxNum TxNum + Metadata Metadata +} + type VersionedResultsIterator = collections.Iterator[*VersionedRead] type QueryExecutor interface { diff --git a/platform/fabric/driver/iterator.go b/platform/fabric/driver/iterator.go index 35050d6db..a1bc61859 100644 --- a/platform/fabric/driver/iterator.go +++ b/platform/fabric/driver/iterator.go @@ -36,4 +36,5 @@ type TXIDStore interface { Iterator(pos interface{}) (TxIDIterator, error) Get(txid string) (ValidationCode, string, error) Set(txID string, code ValidationCode, message string) error + SetMultiple(txs []driver.ByNum[ValidationCode]) error } diff --git a/platform/view/services/db/driver/badger/badger.go b/platform/view/services/db/driver/badger/badger.go index 71c4bba12..e65d0eea6 100644 --- a/platform/view/services/db/driver/badger/badger.go +++ b/platform/view/services/db/driver/badger/badger.go @@ -173,10 +173,10 @@ func (db *DB) SetStateMetadata(namespace driver2.Namespace, key driver2.PKey, me return nil } -func (db *DB) SetStateMetadatas(ns driver2.Namespace, kvs map[driver2.PKey]driver2.Metadata, block driver2.BlockNum, txnum driver2.TxNum) map[driver2.PKey]error { +func (db *DB) SetStateMetadatas(ns driver2.Namespace, kvs map[driver2.PKey]driver2.VersionedMetadataValue) map[driver2.PKey]error { errs := make(map[driver2.PKey]error) for pkey, value := range kvs { - if err := db.SetStateMetadata(ns, pkey, value, block, txnum); err != nil { + if err := db.SetStateMetadata(ns, pkey, value.Metadata, value.Block, value.TxNum); err != nil { errs[pkey] = err } } diff --git a/platform/view/services/db/driver/driver.go b/platform/view/services/db/driver/driver.go index 21feffb78..90ca8e129 100644 --- a/platform/view/services/db/driver/driver.go +++ b/platform/view/services/db/driver/driver.go @@ -27,6 +27,8 @@ type VersionedValue struct { TxNum driver.TxNum } +type VersionedMetadataValue = driver.VersionedMetadataValue + type UnversionedRead struct { Key driver.PKey Raw driver.RawValue @@ -90,7 +92,7 @@ type VersionedPersistence interface { // SetStateMetadata sets the given metadata for the given namespace, key, and version SetStateMetadata(namespace driver.Namespace, key driver.PKey, metadata driver.Metadata, block driver.BlockNum, txnum driver.TxNum) error // SetStateMetadatas sets the given metadata for the given namespace, keys, and version - SetStateMetadatas(ns driver.Namespace, kvs map[driver.PKey]driver.Metadata, block driver.BlockNum, txnum driver.TxNum) map[driver.PKey]error + SetStateMetadatas(ns driver.Namespace, kvs map[driver.PKey]driver.VersionedMetadataValue) map[driver.PKey]error } type WriteTransaction interface { diff --git a/platform/view/services/db/driver/notifier/persistence.go b/platform/view/services/db/driver/notifier/persistence.go index b3cb5602f..41dc3d82e 100644 --- a/platform/view/services/db/driver/notifier/persistence.go +++ b/platform/view/services/db/driver/notifier/persistence.go @@ -199,8 +199,8 @@ func (db *VersionedPersistenceNotifier[P]) SetStateMetadata(namespace driver2.Na return db.Persistence.SetStateMetadata(namespace, key, metadata, block, txnum) } -func (db *VersionedPersistenceNotifier[P]) SetStateMetadatas(ns driver2.Namespace, kvs map[driver2.PKey]driver2.Metadata, block driver2.BlockNum, txnum driver2.TxNum) map[driver2.PKey]error { - return db.Persistence.SetStateMetadatas(ns, kvs, block, txnum) +func (db *VersionedPersistenceNotifier[P]) SetStateMetadatas(ns driver2.Namespace, kvs map[driver2.PKey]driver2.VersionedMetadataValue) map[driver2.PKey]error { + return db.Persistence.SetStateMetadatas(ns, kvs) } func (db *VersionedPersistenceNotifier[P]) GetStateRangeScanIterator(namespace driver2.Namespace, startKey, endKey driver2.PKey) (driver.VersionedResultsIterator, error) { diff --git a/platform/view/services/db/driver/sql/common/base.go b/platform/view/services/db/driver/sql/common/base.go index d331730d5..36da0b3ae 100644 --- a/platform/view/services/db/driver/sql/common/base.go +++ b/platform/view/services/db/driver/sql/common/base.go @@ -34,7 +34,7 @@ type readScanner[V any] interface { ReadValue(scannable) (V, error) } -type valueScanner[V any] interface { +type ValueScanner[V any] interface { readScanner[V] // WriteValue writes the values of the V struct in the order given by the Columns method WriteValue(V) []any @@ -47,12 +47,12 @@ type BasePersistence[V any, R any] struct { table string readScanner readScanner[R] - ValueScanner valueScanner[V] + ValueScanner ValueScanner[V] errorWrapper driver.SQLErrorWrapper ci Interpreter } -func NewBasePersistence[V any, R any](writeDB *sql.DB, readDB *sql.DB, table string, readScanner readScanner[R], valueScanner valueScanner[V], errorWrapper driver.SQLErrorWrapper, ci Interpreter, newTransaction func() (*sql.Tx, error)) *BasePersistence[V, R] { +func NewBasePersistence[V any, R any](writeDB *sql.DB, readDB *sql.DB, table string, readScanner readScanner[R], valueScanner ValueScanner[V], errorWrapper driver.SQLErrorWrapper, ci Interpreter, newTransaction func() (*sql.Tx, error)) *BasePersistence[V, R] { return &BasePersistence[V, R]{ BaseDB: common.NewBaseDB[*sql.Tx](func() (*sql.Tx, error) { return newTransaction() }), readDB: readDB, @@ -207,6 +207,20 @@ func (db *BasePersistence[V, R]) SetStateWithTx(tx *sql.Tx, ns driver2.Namespace val = append([]byte(nil), val...) values[valIndex] = val + return db.UpsertStateWithTx(tx, ns, pkey, keys, values) +} + +func (db *BasePersistence[V, R]) UpsertStates(ns driver2.Namespace, valueKeys []string, vals map[driver2.PKey][]any) map[driver2.PKey]error { + errs := make(map[driver2.PKey]error) + for pkey, val := range vals { + if err := db.UpsertStateWithTx(db.Txn, ns, pkey, valueKeys, val); err != nil { + errs[pkey] = err + } + } + return errs +} + +func (db *BasePersistence[V, R]) UpsertStateWithTx(tx *sql.Tx, ns driver2.Namespace, pkey driver2.PKey, keys []string, values []any) error { // Portable upsert exists, err := db.exists(tx, ns, pkey) if err != nil { diff --git a/platform/view/services/db/driver/sql/common/versioned.go b/platform/view/services/db/driver/sql/common/versioned.go index 4a36fcf00..619d72b12 100644 --- a/platform/view/services/db/driver/sql/common/versioned.go +++ b/platform/view/services/db/driver/sql/common/versioned.go @@ -11,6 +11,7 @@ import ( "database/sql" "encoding/gob" "fmt" + "strings" "github.com/hyperledger-labs/fabric-smart-client/pkg/utils/errors" driver2 "github.com/hyperledger-labs/fabric-smart-client/platform/common/driver" @@ -22,6 +23,10 @@ func NewVersionedReadScanner() *versionedReadScanner { return &versionedReadScan func NewVersionedValueScanner() *versionedValueScanner { return &versionedValueScanner{} } +func NewVersionedMetadataValueScanner() *versionedMetadataValueScanner { + return &versionedMetadataValueScanner{} +} + type versionedReadScanner struct{} func (s *versionedReadScanner) Columns() []string { return []string{"pkey", "block", "txnum", "val"} } @@ -46,24 +51,54 @@ func (s *versionedValueScanner) WriteValue(value driver.VersionedValue) []any { return []any{value.Raw, value.Block, value.TxNum} } +type versionedMetadataValueScanner struct{} + +func (s *versionedMetadataValueScanner) Columns() []string { + return []string{"metadata", "block", "txnum"} +} + +func (s *versionedMetadataValueScanner) ReadValue(txs scannable) (driver2.VersionedMetadataValue, error) { + var r driver2.VersionedMetadataValue + var metadata []byte + if err := txs.Scan(&metadata, &r.Block, &r.TxNum); err != nil { + return r, err + } else if meta, err := unmarshalMetadata(metadata); err != nil { + return r, fmt.Errorf("error decoding metadata: %w", err) + } else { + r.Metadata = meta + } + return r, nil +} + +func (s *versionedMetadataValueScanner) WriteValue(value driver2.VersionedMetadataValue) ([]any, error) { + metadata, err := marshallMetadata(value.Metadata) + if err != nil { + return nil, err + } + return []any{metadata, value.Block, value.TxNum}, nil +} + type basePersistence[V any, R any] interface { driver.BasePersistence[driver.VersionedValue, driver.VersionedRead] + hasKey(ns driver2.Namespace, pkey string) Condition Exists(namespace driver2.Namespace, key driver2.PKey) (bool, error) Exec(query string, args ...any) (sql.Result, error) SetStateWithTx(tx *sql.Tx, namespace driver2.Namespace, key string, value driver.VersionedValue) error DeleteStateWithTx(tx *sql.Tx, ns driver2.Namespace, key driver2.PKey) error + UpsertStates(ns driver2.Namespace, valueKeys []string, vals map[driver2.PKey][]any) map[driver2.PKey]error } type VersionedPersistence struct { basePersistence[driver.VersionedValue, driver.VersionedRead] - table string - errorWrapper driver.SQLErrorWrapper - readDB *sql.DB - writeDB *sql.DB + table string + errorWrapper driver.SQLErrorWrapper + readDB *sql.DB + writeDB *sql.DB + metadataScanner *versionedMetadataValueScanner } func NewVersionedPersistence(base basePersistence[driver.VersionedValue, driver.VersionedRead], table string, errorWrapper driver.SQLErrorWrapper, readDB *sql.DB, writeDB *sql.DB) *VersionedPersistence { - return &VersionedPersistence{basePersistence: base, table: table, errorWrapper: errorWrapper, readDB: readDB, writeDB: writeDB} + return &VersionedPersistence{basePersistence: base, table: table, errorWrapper: errorWrapper, readDB: readDB, writeDB: writeDB, metadataScanner: NewVersionedMetadataValueScanner()} } func NewVersioned(readDB *sql.DB, writeDB *sql.DB, table string, errorWrapper driver.SQLErrorWrapper, ci Interpreter) *VersionedPersistence { @@ -72,74 +107,43 @@ func NewVersioned(readDB *sql.DB, writeDB *sql.DB, table string, errorWrapper dr } func (db *VersionedPersistence) SetStateMetadata(ns driver2.Namespace, key driver2.PKey, metadata driver2.Metadata, block driver2.BlockNum, txnum driver2.TxNum) error { - if ns == "" || key == "" { - return errors.New("ns or key is empty") - } if len(metadata) == 0 { return nil } - m, err := marshallMetadata(metadata) - if err != nil { - return fmt.Errorf("error encoding metadata: %w", err) - } - - exists, err := db.Exists(ns, key) - if err != nil { - return err - } - if exists { - // Note: for consistency with badger we also update the block and txnum - query := fmt.Sprintf("UPDATE %s SET metadata = $1, block = $2, txnum = $3 WHERE ns = $4 AND pkey = $5", db.table) - logger.Debug(query, len(m), block, txnum, ns, key) - _, err = db.Exec(query, m, block, txnum, ns, key) - if err != nil { - return errors2.Wrapf(db.errorWrapper.WrapError(err), "could not set metadata for key [%s]", key) - } - } else { - logger.Warnf("storing metadata without existing value at [%s]", key) - query := fmt.Sprintf("INSERT INTO %s (ns, pkey, metadata, block, txnum) VALUES ($1, $2, $3, $4, $5)", db.table) - logger.Debug(query, ns, key, len(m), block, txnum) - _, err = db.Exec(query, ns, key, m, block, txnum) - if err != nil { - return errors2.Wrapf(db.errorWrapper.WrapError(err), "could not set metadata for key [%s]", key) - } - } - - return nil + return db.SetStateMetadatas(ns, map[driver2.PKey]driver2.VersionedMetadataValue{key: {Block: block, TxNum: txnum, Metadata: metadata}})[key] } -func (db *VersionedPersistence) SetStateMetadatas(ns driver2.Namespace, kvs map[driver2.PKey]driver2.Metadata, block driver2.BlockNum, txnum driver2.TxNum) map[driver2.PKey]error { +func (db *VersionedPersistence) SetStateMetadatas(ns driver2.Namespace, kvs map[driver2.PKey]driver2.VersionedMetadataValue) map[driver2.PKey]error { errs := make(map[driver2.PKey]error) - for pkey, value := range kvs { - if err := db.SetStateMetadata(ns, pkey, value, block, txnum); err != nil { + vals := make(map[driver2.PKey][]any, len(kvs)) + for pkey, meta := range kvs { + if val, err := db.metadataScanner.WriteValue(meta); err != nil { errs[pkey] = err + } else { + vals[pkey] = val } } - return errs + if len(errs) > 0 { + return errs + } + return db.UpsertStates(ns, db.metadataScanner.Columns(), vals) } -func (db *VersionedPersistence) GetStateMetadata(ns driver2.Namespace, key driver2.PKey) (driver2.Metadata, driver2.BlockNum, driver2.TxNum, error) { - var m []byte - var meta map[string][]byte - var block, txnum uint64 +// TODO: AF Reuse code from basePersistence +func (db *VersionedPersistence) GetStateMetadata(namespace driver2.Namespace, key driver2.PKey) (driver2.Metadata, driver2.BlockNum, driver2.TxNum, error) { + where, args := Where(db.hasKey(namespace, key)) + query := fmt.Sprintf("SELECT %s FROM %s %s", strings.Join(db.metadataScanner.Columns(), ", "), db.table, where) + logger.Debug(query, args) - query := fmt.Sprintf("SELECT metadata, block, txnum FROM %s WHERE ns = $1 AND pkey = $2", db.table) - logger.Debug(query, ns, key) - - row := db.readDB.QueryRow(query, ns, key) - if err := row.Scan(&m, &block, &txnum); err != nil { - if err == sql.ErrNoRows { - logger.Debugf("not found: [%s:%s]", ns, key) - return meta, block, txnum, nil - } - return meta, block, txnum, fmt.Errorf("error querying db: %w", err) - } - meta, err := unmarshalMetadata(m) - if err != nil { - return meta, block, txnum, fmt.Errorf("error decoding metadata: %w", err) + row := db.readDB.QueryRow(query, args...) + if value, err := db.metadataScanner.ReadValue(row); err == nil { + return value.Metadata, value.Block, value.TxNum, nil + } else if err == sql.ErrNoRows { + logger.Debugf("not found: [%s:%s]", namespace, key) + return nil, 0, 0, nil + } else { + return nil, 0, 0, errors2.Wrapf(err, "error querying db: %s", query) } - - return meta, block, txnum, err } func (db *VersionedPersistence) CreateSchema() error { diff --git a/platform/view/services/db/driver/sql/postgres/base.go b/platform/view/services/db/driver/sql/postgres/base.go index d1263eef5..9a2705e46 100644 --- a/platform/view/services/db/driver/sql/postgres/base.go +++ b/platform/view/services/db/driver/sql/postgres/base.go @@ -46,12 +46,11 @@ func (db *BasePersistence[V, R]) setStatesWithTx(tx *sql.Tx, ns driver.Namespace if tx == nil { panic("programming error, writing without ongoing update") } - keys := db.valueKeys() + keys := db.ValueScanner.Columns() valIndex := slices.Index(keys, "val") upserted := make(map[driver.PKey][]any, len(kvs)) deleted := make([]driver.PKey, 0, len(kvs)) for pkey, value := range kvs { - values := db.ValueScanner.WriteValue(value) // Get rawVal if val := values[valIndex].([]byte); len(val) == 0 { @@ -66,46 +65,44 @@ func (db *BasePersistence[V, R]) setStatesWithTx(tx *sql.Tx, ns driver.Namespace } } - var errs map[driver.PKey]error + errs := make(map[driver.PKey]error) if len(deleted) > 0 { collections.CopyMap(errs, db.DeleteStatesWithTx(tx, ns, deleted...)) } if len(upserted) > 0 { - collections.CopyMap(errs, db.upsertStatesWithTx(tx, ns, upserted)) + collections.CopyMap(errs, db.upsertStatesWithTx(tx, ns, keys, upserted)) } return errs } -func (db *BasePersistence[V, R]) upsertStatesWithTx(tx *sql.Tx, ns driver.Namespace, upserted map[driver.PKey][]any) map[driver.PKey]error { - keys := append([]string{"ns", "pkey"}, db.valueKeys()...) +func (db *BasePersistence[V, R]) UpsertStates(ns driver.Namespace, valueKeys []string, vals map[driver.PKey][]any) map[driver.PKey]error { + return db.upsertStatesWithTx(db.Txn, ns, valueKeys, vals) +} + +func (db *BasePersistence[V, R]) upsertStatesWithTx(tx *sql.Tx, ns driver.Namespace, valueKeys []string, vals map[driver.PKey][]any) map[driver.PKey]error { + keys := append([]string{"ns", "pkey"}, valueKeys...) query := fmt.Sprintf("INSERT INTO %s (%s) "+ "VALUES %s "+ "ON CONFLICT (ns, pkey) DO UPDATE "+ "SET %s", db.table, strings.Join(keys, ", "), - common.CreateParamsMatrix(len(keys), len(upserted), 1), - strings.Join(db.substitutions(), ", ")) + common.CreateParamsMatrix(len(keys), len(vals), 1), + strings.Join(substitutions(valueKeys), ", ")) - args := make([]any, 0, len(keys)*len(upserted)) - for pkey, vals := range upserted { + args := make([]any, 0, len(keys)*len(vals)) + for pkey, vals := range vals { args = append(append(args, ns, pkey), vals...) } logger.Debug(query, args) if _, err := tx.Exec(query, args...); err != nil { - return collections.RepeatValue(collections.Keys(upserted), errors2.Wrapf(db.errorWrapper.WrapError(err), "could not upsert")) + return collections.RepeatValue(collections.Keys(vals), errors2.Wrapf(db.errorWrapper.WrapError(err), "could not upsert")) } return nil } // TODO: AF Needs to be calculated only once -func (db *BasePersistence[V, R]) valueKeys() []string { - return db.ValueScanner.Columns() -} - -// TODO: AF Needs to be calculated only once -func (db *BasePersistence[V, R]) substitutions() []string { - keys := db.valueKeys() +func substitutions(keys []string) []string { subs := make([]string, len(keys)) for i, key := range keys { subs[i] = fmt.Sprintf("%s = excluded.%s", key, key) diff --git a/platform/view/services/db/driver/sql/postgres/versioned.go b/platform/view/services/db/driver/sql/postgres/versioned.go index 92b710dcf..493cd135e 100644 --- a/platform/view/services/db/driver/sql/postgres/versioned.go +++ b/platform/view/services/db/driver/sql/postgres/versioned.go @@ -83,8 +83,8 @@ func (db *VersionedPersistence) SetStateMetadata(namespace driver2.Namespace, ke return db.p.SetStateMetadata(namespace, key, metadata, block, txnum) } -func (db *VersionedPersistence) SetStateMetadatas(ns driver2.Namespace, kvs map[driver2.PKey]driver2.Metadata, block driver2.BlockNum, txnum driver2.TxNum) map[driver2.PKey]error { - return db.p.SetStateMetadatas(ns, kvs, block, txnum) +func (db *VersionedPersistence) SetStateMetadatas(ns driver2.Namespace, kvs map[driver2.PKey]driver2.VersionedMetadataValue) map[driver2.PKey]error { + return db.p.SetStateMetadatas(ns, kvs) } func (db *VersionedPersistence) CreateSchema() error {