From 1dbd42574aa9ebde789bb5a98edd800a42d28ae9 Mon Sep 17 00:00:00 2001 From: Paul Nicolas Date: Wed, 29 Mar 2023 14:25:29 +0200 Subject: [PATCH] feat: optimize queries package (#187) --- pkg/api/internal/testing.go | 14 +- pkg/api/middlewares/ledger_middleware.go | 4 +- pkg/bus/module.go | 4 +- pkg/bus/monitor.go | 4 +- pkg/core/transaction.go | 5 + pkg/ledger/ledger.go | 5 + pkg/ledger/main_test.go | 14 +- pkg/ledger/module.go | 8 +- pkg/ledger/{query => monitor}/monitor.go | 2 +- pkg/ledger/query/module.go | 35 -- pkg/ledger/query/worker.go | 410 +++++++++++++----- pkg/ledger/query/worker_test.go | 5 +- pkg/ledger/resolver.go | 23 +- pkg/ledgertesting/storage.go | 4 + pkg/storage/driver.go | 32 +- pkg/storage/sqlstorage/ledger/accounts.go | 43 +- .../sqlstorage/ledger/accounts_test.go | 71 +++ pkg/storage/sqlstorage/ledger/logs.go | 15 +- pkg/storage/sqlstorage/ledger/store.go | 31 ++ pkg/storage/sqlstorage/ledger/transactions.go | 30 ++ .../sqlstorage/ledger/transactions_test.go | 147 +++++++ pkg/storage/sqlstorage/ledger/volumes.go | 45 +- pkg/storage/sqlstorage/ledger/volumes_test.go | 92 ++++ pkg/storage/sqlstorage/schema/schema.go | 25 +- 24 files changed, 846 insertions(+), 222 deletions(-) rename pkg/ledger/{query => monitor}/monitor.go (98%) delete mode 100644 pkg/ledger/query/module.go create mode 100644 pkg/storage/sqlstorage/ledger/transactions_test.go create mode 100644 pkg/storage/sqlstorage/ledger/volumes_test.go diff --git a/pkg/api/internal/testing.go b/pkg/api/internal/testing.go index 392e4fbbe..4d05c88e4 100644 --- a/pkg/api/internal/testing.go +++ b/pkg/api/internal/testing.go @@ -16,7 +16,7 @@ import ( "github.com/formancehq/ledger/pkg/core" "github.com/formancehq/ledger/pkg/ledger" "github.com/formancehq/ledger/pkg/ledger/lock" - "github.com/formancehq/ledger/pkg/ledger/query" + "github.com/formancehq/ledger/pkg/ledger/monitor" "github.com/formancehq/ledger/pkg/ledgertesting" "github.com/formancehq/ledger/pkg/storage" sharedapi "github.com/formancehq/stack/libs/go-libs/api" @@ -206,12 +206,14 @@ func RunTest(t *testing.T, callback func(api chi.Router, storageDriver storage.D storageDriver := ledgertesting.StorageDriver(t) require.NoError(t, storageDriver.Initialize(context.Background())) - queryWorker := query.NewWorker(query.DefaultWorkerConfig, storageDriver, query.NewNoOpMonitor()) - go func() { - require.NoError(t, queryWorker.Run(context.Background())) - }() + ledgerStore, _, err := storageDriver.GetLedgerStore(context.Background(), uuid.New(), true) + require.NoError(t, err) + + modified, err := ledgerStore.Initialize(context.Background()) + require.NoError(t, err) + require.True(t, modified) - resolver := ledger.NewResolver(storageDriver, lock.NewInMemory(), queryWorker, false) + resolver := ledger.NewResolver(storageDriver, monitor.NewNoOpMonitor(), lock.NewInMemory(), false) router := routes.NewRouter(storageDriver, "latest", resolver, logging.FromContext(context.Background()), &health.HealthController{}) diff --git a/pkg/api/middlewares/ledger_middleware.go b/pkg/api/middlewares/ledger_middleware.go index 10729fc5f..6140ec5c2 100644 --- a/pkg/api/middlewares/ledger_middleware.go +++ b/pkg/api/middlewares/ledger_middleware.go @@ -1,7 +1,6 @@ package middlewares import ( - "context" "net/http" "github.com/formancehq/ledger/pkg/api/apierrors" @@ -34,7 +33,8 @@ func LedgerMiddleware(resolver *ledger.Resolver) func(handler http.Handler) http apierrors.ResponseError(w, r, err) return } - 
defer l.Close(context.Background()) + // TODO(polo/gfyrag): close ledger if not used for x minutes + // defer l.Close(context.Background()) r = r.WithContext(controllers.ContextWithLedger(r.Context(), l)) diff --git a/pkg/bus/module.go b/pkg/bus/module.go index 2afd78990..378194b0c 100644 --- a/pkg/bus/module.go +++ b/pkg/bus/module.go @@ -1,10 +1,10 @@ package bus import ( - "github.com/formancehq/ledger/pkg/ledger/query" + "github.com/formancehq/ledger/pkg/ledger/monitor" "go.uber.org/fx" ) func LedgerMonitorModule() fx.Option { - return fx.Decorate(fx.Annotate(newLedgerMonitor, fx.As(new(query.Monitor)))) + return fx.Decorate(fx.Annotate(newLedgerMonitor, fx.As(new(monitor.Monitor)))) } diff --git a/pkg/bus/monitor.go b/pkg/bus/monitor.go index 577291b7f..8807d56ac 100644 --- a/pkg/bus/monitor.go +++ b/pkg/bus/monitor.go @@ -5,7 +5,7 @@ import ( "github.com/ThreeDotsLabs/watermill/message" "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/ledger/query" + "github.com/formancehq/ledger/pkg/ledger/monitor" "github.com/formancehq/stack/libs/go-libs/logging" "github.com/formancehq/stack/libs/go-libs/publish" ) @@ -14,7 +14,7 @@ type ledgerMonitor struct { publisher message.Publisher } -var _ query.Monitor = &ledgerMonitor{} +var _ monitor.Monitor = &ledgerMonitor{} func newLedgerMonitor(publisher message.Publisher) *ledgerMonitor { m := &ledgerMonitor{ diff --git a/pkg/core/transaction.go b/pkg/core/transaction.go index 408bafe1d..7f57e9b5e 100644 --- a/pkg/core/transaction.go +++ b/pkg/core/transaction.go @@ -48,6 +48,11 @@ type Transaction struct { ID uint64 `json:"txid"` } +type TransactionWithMetadata struct { + ID uint64 + Metadata Metadata +} + func (t Transaction) WithPostings(postings ...Posting) Transaction { t.TransactionData = t.TransactionData.WithPostings(postings...) 
return t diff --git a/pkg/ledger/ledger.go b/pkg/ledger/ledger.go index 71e27759b..8e3387640 100644 --- a/pkg/ledger/ledger.go +++ b/pkg/ledger/ledger.go @@ -42,6 +42,11 @@ func (l *Ledger) Close(ctx context.Context) error { if err := l.store.Close(ctx); err != nil { return errors.Wrap(err, "closing store") } + + if err := l.queryWorker.Stop(ctx); err != nil { + return errors.Wrap(err, "stopping query worker") + } + return nil } diff --git a/pkg/ledger/main_test.go b/pkg/ledger/main_test.go index 9ba61308c..29bf6ff83 100644 --- a/pkg/ledger/main_test.go +++ b/pkg/ledger/main_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/formancehq/ledger/pkg/ledger/lock" - "github.com/formancehq/ledger/pkg/ledger/query" + "github.com/formancehq/ledger/pkg/ledger/monitor" "github.com/formancehq/ledger/pkg/ledgertesting" "github.com/formancehq/stack/libs/go-libs/logging" "github.com/formancehq/stack/libs/go-libs/pgtesting" @@ -30,12 +30,14 @@ func newResolver(t interface{ pgtesting.TestingT }) *Resolver { storageDriver := ledgertesting.StorageDriver(t) require.NoError(t, storageDriver.Initialize(context.Background())) - queryWorker := query.NewWorker(query.DefaultWorkerConfig, storageDriver, query.NewNoOpMonitor()) - go func() { - require.NoError(t, queryWorker.Run(context.Background())) - }() + ledgerStore, _, err := storageDriver.GetLedgerStore(context.Background(), uuid.New(), true) + require.NoError(t, err) + + modified, err := ledgerStore.Initialize(context.Background()) + require.NoError(t, err) + require.True(t, modified) - return NewResolver(storageDriver, lock.NewInMemory(), queryWorker, false) + return NewResolver(storageDriver, monitor.NewNoOpMonitor(), lock.NewInMemory(), false) } func runOnLedger(t interface { diff --git a/pkg/ledger/module.go b/pkg/ledger/module.go index 702e7b1db..3f8e6a45f 100644 --- a/pkg/ledger/module.go +++ b/pkg/ledger/module.go @@ -2,7 +2,7 @@ package ledger import ( "github.com/formancehq/ledger/pkg/ledger/lock" - "github.com/formancehq/ledger/pkg/ledger/query" + "github.com/formancehq/ledger/pkg/ledger/monitor" "github.com/formancehq/ledger/pkg/storage" "go.uber.org/fx" ) @@ -12,11 +12,11 @@ func Module(allowPastTimestamp bool) fx.Option { lock.Module(), fx.Provide(func( storageDriver storage.Driver, + monitor monitor.Monitor, locker lock.Locker, - queryWorker *query.Worker, ) *Resolver { - return NewResolver(storageDriver, locker, queryWorker, allowPastTimestamp) + return NewResolver(storageDriver, monitor, locker, allowPastTimestamp) }), - query.Module(), + fx.Provide(fx.Annotate(monitor.NewNoOpMonitor, fx.As(new(monitor.Monitor)))), ) } diff --git a/pkg/ledger/query/monitor.go b/pkg/ledger/monitor/monitor.go similarity index 98% rename from pkg/ledger/query/monitor.go rename to pkg/ledger/monitor/monitor.go index 4dc6fc417..2fbb269eb 100644 --- a/pkg/ledger/query/monitor.go +++ b/pkg/ledger/monitor/monitor.go @@ -1,4 +1,4 @@ -package query +package monitor import ( "context" diff --git a/pkg/ledger/query/module.go b/pkg/ledger/query/module.go deleted file mode 100644 index 24769234a..000000000 --- a/pkg/ledger/query/module.go +++ /dev/null @@ -1,35 +0,0 @@ -package query - -import ( - "context" - - "github.com/formancehq/stack/libs/go-libs/logging" - "go.uber.org/fx" -) - -func Module() fx.Option { - return fx.Options( - fx.Supply(workerConfig{ - // TODO(gfyrag): Probably need to be configurable - ChanSize: 1024, - }), - fx.Provide(NewWorker), - 
fx.Provide(fx.Annotate(NewNoOpMonitor, fx.As(new(Monitor)))), - fx.Invoke(func(worker *Worker, lc fx.Lifecycle) { - lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { - go func() { - if err := worker.Run(logging.ContextWithLogger( - context.Background(), - logging.FromContext(ctx), - )); err != nil { - panic(err) - } - }() - return nil - }, - OnStop: worker.Stop, - }) - }), - ) -} diff --git a/pkg/ledger/query/worker.go b/pkg/ledger/query/worker.go index 8e7b64ee7..2d1970127 100644 --- a/pkg/ledger/query/worker.go +++ b/pkg/ledger/query/worker.go @@ -6,33 +6,45 @@ import ( "github.com/formancehq/ledger/pkg/core" "github.com/formancehq/ledger/pkg/ledger/aggregator" + "github.com/formancehq/ledger/pkg/ledger/monitor" "github.com/formancehq/ledger/pkg/storage" "github.com/formancehq/stack/libs/go-libs/logging" "github.com/pkg/errors" ) var ( - DefaultWorkerConfig = workerConfig{ + DefaultWorkerConfig = WorkerConfig{ ChanSize: 100, } ) -type workerConfig struct { +type WorkerConfig struct { ChanSize int } -type logHolder struct { - *core.LogHolder - store storage.LedgerStore +type logsData struct { + accountsToUpdate []core.Account + ensureAccountsExist []string + transactionsToInsert []core.ExpandedTransaction + transactionsToUpdate []core.TransactionWithMetadata + volumesToUpdate []core.AccountsAssetsVolumes + monitors []func(context.Context, monitor.Monitor) } type Worker struct { - workerConfig - ctx context.Context - logChan chan logHolder - stopChan chan chan struct{} + WorkerConfig + ctx context.Context + + pending []*core.LogHolder + writeChannel chan *core.LogHolder + jobs chan []*core.LogHolder + releasedJob chan struct{} + errorChan chan error + stopChan chan chan struct{} + driver storage.Driver - monitor Monitor + store storage.LedgerStore + monitor monitor.Monitor lastProcessedLogID *uint64 } @@ -64,6 +76,54 @@ func (w *Worker) Run(ctx context.Context) error { } } +func (w *Worker) writeLoop(ctx context.Context) { + closeLogs := func(logs []*core.LogHolder) { + for _, log := range logs { + close(log.Ingested) + } + } + + for { + select { + case <-ctx.Done(): + return + + case w.releasedJob <- struct{}{}: + + case modelsHolder := <-w.jobs: + logs := make([]core.Log, len(modelsHolder)) + for i, holder := range modelsHolder { + logs[i] = *holder.Log + } + + if err := w.processLogs(w.ctx, logs...); err != nil { + if err == context.Canceled { + logging.FromContext(w.ctx).Debugf("CQRS worker canceled") + } else { + logging.FromContext(w.ctx).Errorf("CQRS worker error: %s", err) + } + closeLogs(modelsHolder) + + // Return the error to restart the worker + w.errorChan <- err + return + } + + if err := w.store.UpdateNextLogID(w.ctx, logs[len(logs)-1].ID+1); err != nil { + logging.FromContext(w.ctx).Errorf("CQRS worker error: %s", err) + closeLogs(modelsHolder) + + // TODO(polo/gfyrag): add indempotency tests + // Return the error to restart the worker + w.errorChan <- err + return + } + + closeLogs(modelsHolder) + } + } +} + func (w *Worker) run() error { if err := w.initLedgers(w.ctx); err != nil { if err == context.Canceled { @@ -75,39 +135,86 @@ func (w *Worker) run() error { return err } + ctx, cancel := context.WithCancel(w.ctx) + defer cancel() + + go w.writeLoop(ctx) + +l: for { select { - case <-w.ctx.Done(): - return w.ctx.Err() + case <-ctx.Done(): + return ctx.Err() + case stopChan := <-w.stopChan: - logging.FromContext(w.ctx).Debugf("CQRS worker stopped") + logging.FromContext(ctx).Debugf("CQRS worker stopped") 
close(stopChan) return nil - case wl := <-w.logChan: + + case err := <-w.errorChan: + // In this case, we failed to write the models, so we need to + // restart the worker + logging.FromContext(ctx).Debugf("write loop error: %s", err) + return err + + // At this level, the job is writting some models, just accumulate models in a buffer + case wl := <-w.writeChannel: if w.lastProcessedLogID != nil && wl.Log.ID <= *w.lastProcessedLogID { close(wl.Ingested) continue } - if err := w.processLog(w.ctx, wl.store, wl.Log); err != nil { - if err == context.Canceled { - logging.FromContext(w.ctx).Debugf("CQRS worker canceled") - } else { - logging.FromContext(w.ctx).Errorf("CQRS worker error: %s", err) - } - close(wl.Ingested) - // Return the error to restart the worker - return err + w.pending = append(w.pending, wl) + case <-w.releasedJob: + // There, write model job is not running, and we have pending models + // So we can try to send pending to the job channel + if len(w.pending) > 0 { + for { + select { + case <-ctx.Done(): + return ctx.Err() + + case stopChan := <-w.stopChan: + logging.FromContext(ctx).Debugf("CQRS worker stopped") + close(stopChan) + return nil + + case err := <-w.errorChan: + logging.FromContext(ctx).Debugf("write loop error: %s", err) + return err + + case w.jobs <- w.pending: + w.pending = make([]*core.LogHolder, 0) + continue l + } + } } + select { + case <-ctx.Done(): + return ctx.Err() - if err := wl.store.UpdateNextLogID(w.ctx, wl.Log.ID+1); err != nil { - logging.FromContext(w.ctx).Errorf("CQRS worker error: %s", err) - close(wl.Ingested) - // TODO(polo/gfyrag): add indempotency tests - // Return the error to restart the worker + case stopChan := <-w.stopChan: + logging.FromContext(ctx).Debugf("CQRS worker stopped") + close(stopChan) + return nil + + case err := <-w.errorChan: + logging.FromContext(ctx).Debugf("write loop error: %s", err) return err + + // There, the job is waiting, and we don't have any pending models to write + // so, wait for new models to write and send them directly to the job channel + // We can not return to the main loop as w.releasedJob will be continuously notified by the job routine + case mh := <-w.writeChannel: + select { + case <-ctx.Done(): + return ctx.Err() + case stopChan := <-w.stopChan: + close(stopChan) + return nil + case w.jobs <- []*core.LogHolder{mh}: + } } - close(wl.Ingested) } } } @@ -171,7 +278,7 @@ func (w *Worker) initLedger(ctx context.Context, ledger string) error { return nil } - if err := w.processLogs(ctx, store, logs...); err != nil { + if err := w.processLogs(ctx, logs...); err != nil { return errors.Wrap(err, "processing logs") } @@ -184,125 +291,204 @@ func (w *Worker) initLedger(ctx context.Context, ledger string) error { return nil } -func (w *Worker) processLogs(ctx context.Context, store storage.LedgerStore, logs ...core.Log) error { - for _, log := range logs { - if err := w.processLog(ctx, store, &log); err != nil { - return errors.Wrapf(err, "processing log %d", log.ID) - } +func (w *Worker) processLogs(ctx context.Context, logs ...core.Log) error { + + logsData, err := w.buildData(ctx, logs...) 
+ if err != nil { + return errors.Wrap(err, "building data") } - return nil -} + if err := w.store.RunInTransaction(ctx, func(ctx context.Context, tx storage.LedgerStore) error { + if len(logsData.accountsToUpdate) > 0 { + if err := tx.UpdateAccountsMetadata(ctx, logsData.accountsToUpdate); err != nil { + return errors.Wrap(err, "updating accounts metadata") + } + } -func (w *Worker) processLog(ctx context.Context, store storage.LedgerStore, log *core.Log) error { - volumeAggregator := aggregator.Volumes(store) + if len(logsData.transactionsToInsert) > 0 { + if err := tx.InsertTransactions(ctx, logsData.transactionsToInsert...); err != nil { + return errors.Wrap(err, "inserting transactions") + } + } - var err error - switch log.Type { - case core.NewTransactionLogType: - payload := log.Data.(core.NewTransactionLogPayload) - txVolumeAggregator, err := volumeAggregator.NextTxWithPostings(ctx, payload.Transaction.Postings...) - if err != nil { - return err + if len(logsData.transactionsToUpdate) > 0 { + if err := tx.UpdateTransactionsMetadata(ctx, logsData.transactionsToUpdate...); err != nil { + return errors.Wrap(err, "updating transactions") + } } - if payload.AccountMetadata != nil { - for account, metadata := range payload.AccountMetadata { - if err := store.UpdateAccountMetadata(ctx, account, metadata); err != nil { - return errors.Wrap(err, "updating account metadata") - } + if len(logsData.ensureAccountsExist) > 0 { + if err := tx.EnsureAccountsExist(ctx, logsData.ensureAccountsExist); err != nil { + return errors.Wrap(err, "ensuring accounts exist") } } - expandedTx := core.ExpandedTransaction{ - Transaction: payload.Transaction, - PreCommitVolumes: txVolumeAggregator.PreCommitVolumes, - PostCommitVolumes: txVolumeAggregator.PostCommitVolumes, + if len(logsData.volumesToUpdate) > 0 { + return tx.UpdateVolumes(ctx, logsData.volumesToUpdate...) } - if err := store.InsertTransactions(ctx, expandedTx); err != nil { - return errors.Wrap(err, "inserting transactions") + return nil + }); err != nil { + return err + } + + if w.monitor != nil { + for _, monitor := range logsData.monitors { + monitor(ctx, w.monitor) } + } + + return nil +} - for account := range txVolumeAggregator.PostCommitVolumes { - if err := store.EnsureAccountExists(ctx, account); err != nil { - return errors.Wrap(err, "ensuring account exists") +func (w *Worker) buildData( + ctx context.Context, + logs ...core.Log, +) (*logsData, error) { + logsData := &logsData{} + + volumeAggregator := aggregator.Volumes(w.store) + accountsToUpdate := make(map[string]core.Metadata) + transactionsToUpdate := make(map[uint64]core.Metadata) + for _, log := range logs { + switch log.Type { + case core.NewTransactionLogType: + payload := log.Data.(core.NewTransactionLogPayload) + txVolumeAggregator, err := volumeAggregator.NextTxWithPostings(ctx, payload.Transaction.Postings...) 
+ if err != nil { + return nil, err } - } - if err := store.UpdateVolumes(ctx, txVolumeAggregator.PostCommitVolumes); err != nil { - return errors.Wrap(err, "updating volumes") - } + if payload.AccountMetadata != nil { + for account, metadata := range payload.AccountMetadata { + if m, ok := accountsToUpdate[account]; !ok { + accountsToUpdate[account] = metadata + } else { + for k, v := range metadata { + m[k] = v + } + } + } + } - if w.monitor != nil { - w.monitor.CommittedTransactions(ctx, store.Name(), expandedTx) - for account, metadata := range payload.AccountMetadata { - w.monitor.SavedMetadata(ctx, store.Name(), core.MetaTargetTypeAccount, account, metadata) + expandedTx := core.ExpandedTransaction{ + Transaction: payload.Transaction, + PreCommitVolumes: txVolumeAggregator.PreCommitVolumes, + PostCommitVolumes: txVolumeAggregator.PostCommitVolumes, } - } - case core.SetMetadataLogType: - setMetadata := log.Data.(core.SetMetadataLogPayload) - switch setMetadata.TargetType { - case core.MetaTargetTypeAccount: - if err := store.UpdateAccountMetadata(ctx, setMetadata.TargetID.(string), setMetadata.Metadata); err != nil { - return errors.Wrap(err, "updating account metadata") + logsData.transactionsToInsert = append(logsData.transactionsToInsert, expandedTx) + + for account := range txVolumeAggregator.PostCommitVolumes { + logsData.ensureAccountsExist = append(logsData.ensureAccountsExist, account) } - case core.MetaTargetTypeTransaction: - if err := store.UpdateTransactionMetadata(ctx, setMetadata.TargetID.(uint64), setMetadata.Metadata); err != nil { - return errors.Wrap(err, "updating transactions metadata") + + logsData.volumesToUpdate = append(logsData.volumesToUpdate, txVolumeAggregator.PostCommitVolumes) + + logsData.monitors = append(logsData.monitors, func(ctx context.Context, monitor monitor.Monitor) { + w.monitor.CommittedTransactions(ctx, w.store.Name(), expandedTx) + for account, metadata := range payload.AccountMetadata { + w.monitor.SavedMetadata(ctx, w.store.Name(), core.MetaTargetTypeAccount, account, metadata) + } + }) + + case core.SetMetadataLogType: + setMetadata := log.Data.(core.SetMetadataLogPayload) + switch setMetadata.TargetType { + case core.MetaTargetTypeAccount: + addr := setMetadata.TargetID.(string) + if m, ok := accountsToUpdate[addr]; !ok { + accountsToUpdate[addr] = setMetadata.Metadata + } else { + for k, v := range setMetadata.Metadata { + m[k] = v + } + } + + case core.MetaTargetTypeTransaction: + id := setMetadata.TargetID.(uint64) + if m, ok := transactionsToUpdate[id]; !ok { + transactionsToUpdate[id] = setMetadata.Metadata + } else { + for k, v := range setMetadata.Metadata { + m[k] = v + } + } } - } - if w.monitor != nil { - w.monitor.SavedMetadata(ctx, store.Name(), store.Name(), fmt.Sprint(setMetadata.TargetID), setMetadata.Metadata) - } - case core.RevertedTransactionLogType: - payload := log.Data.(core.RevertedTransactionLogPayload) - if err := store.UpdateTransactionMetadata(ctx, payload.RevertedTransactionID, - core.RevertedMetadata(payload.RevertTransaction.ID)); err != nil { - return errors.Wrap(err, "updating metadata") - } - txVolumeAggregator, err := volumeAggregator.NextTxWithPostings(ctx, payload.RevertTransaction.Postings...) 
- if err != nil { - return errors.Wrap(err, "aggregating volumes") - } - expandedTx := core.ExpandedTransaction{ - Transaction: payload.RevertTransaction, - PreCommitVolumes: txVolumeAggregator.PreCommitVolumes, - PostCommitVolumes: txVolumeAggregator.PostCommitVolumes, - } - if err := store.InsertTransactions(ctx, expandedTx); err != nil { - return errors.Wrap(err, "inserting transaction") - } + logsData.monitors = append(logsData.monitors, func(ctx context.Context, monitor monitor.Monitor) { + w.monitor.SavedMetadata(ctx, w.store.Name(), w.store.Name(), fmt.Sprint(setMetadata.TargetID), setMetadata.Metadata) + }) + + case core.RevertedTransactionLogType: + payload := log.Data.(core.RevertedTransactionLogPayload) + id := payload.RevertedTransactionID + metadata := core.RevertedMetadata(payload.RevertTransaction.ID) + if m, ok := transactionsToUpdate[id]; !ok { + transactionsToUpdate[id] = metadata + } else { + for k, v := range metadata { + m[k] = v + } + } - if w.monitor != nil { - revertedTx, err := store.GetTransaction(ctx, payload.RevertedTransactionID) + txVolumeAggregator, err := volumeAggregator.NextTxWithPostings(ctx, payload.RevertTransaction.Postings...) if err != nil { - return err + return nil, errors.Wrap(err, "aggregating volumes") + } + + expandedTx := core.ExpandedTransaction{ + Transaction: payload.RevertTransaction, + PreCommitVolumes: txVolumeAggregator.PreCommitVolumes, + PostCommitVolumes: txVolumeAggregator.PostCommitVolumes, + } + logsData.transactionsToInsert = append(logsData.transactionsToInsert, expandedTx) + + revertedTx, err := w.store.GetTransaction(ctx, payload.RevertedTransactionID) + if err != nil { + return nil, err } - w.monitor.RevertedTransaction(ctx, store.Name(), revertedTx, &expandedTx) + + logsData.monitors = append(logsData.monitors, func(ctx context.Context, monitor monitor.Monitor) { + w.monitor.RevertedTransaction(ctx, w.store.Name(), revertedTx, &expandedTx) + }) } } - return err + for account, metadata := range accountsToUpdate { + logsData.accountsToUpdate = append(logsData.accountsToUpdate, core.Account{ + Address: account, + Metadata: metadata, + }) + } + + for transaction, metadata := range transactionsToUpdate { + logsData.transactionsToUpdate = append(logsData.transactionsToUpdate, core.TransactionWithMetadata{ + ID: transaction, + Metadata: metadata, + }) + } + + return logsData, nil } func (w *Worker) QueueLog(ctx context.Context, log *core.LogHolder, store storage.LedgerStore) { select { case <-w.ctx.Done(): - case w.logChan <- logHolder{ - LogHolder: log, - store: store, - }: + case w.writeChannel <- log: } } -func NewWorker(config workerConfig, driver storage.Driver, monitor Monitor) *Worker { +func NewWorker(config WorkerConfig, driver storage.Driver, store storage.LedgerStore, monitor monitor.Monitor) *Worker { return &Worker{ - logChan: make(chan logHolder, config.ChanSize), + pending: make([]*core.LogHolder, 0), + jobs: make(chan []*core.LogHolder), + releasedJob: make(chan struct{}, 1), + writeChannel: make(chan *core.LogHolder, config.ChanSize), + errorChan: make(chan error, 1), stopChan: make(chan chan struct{}), - workerConfig: config, + WorkerConfig: config, + store: store, driver: driver, monitor: monitor, } diff --git a/pkg/ledger/query/worker_test.go b/pkg/ledger/query/worker_test.go index f0200e371..44c198996 100644 --- a/pkg/ledger/query/worker_test.go +++ b/pkg/ledger/query/worker_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/formancehq/ledger/pkg/core" + 
"github.com/formancehq/ledger/pkg/ledger/monitor" "github.com/formancehq/ledger/pkg/ledgertesting" "github.com/formancehq/ledger/pkg/storage" "github.com/formancehq/stack/libs/go-libs/pgtesting" @@ -32,9 +33,9 @@ func TestWorker(t *testing.T) { require.NoError(t, err) require.True(t, modified) - worker := NewWorker(workerConfig{ + worker := NewWorker(WorkerConfig{ ChanSize: 1024, - }, driver, NewNoOpMonitor()) + }, driver, ledgerStore, monitor.NewNoOpMonitor()) go func() { require.NoError(t, worker.Run(context.Background())) }() diff --git a/pkg/ledger/resolver.go b/pkg/ledger/resolver.go index 3a29d0fac..ebf57e7c9 100644 --- a/pkg/ledger/resolver.go +++ b/pkg/ledger/resolver.go @@ -6,18 +6,20 @@ import ( "github.com/formancehq/ledger/pkg/ledger/cache" "github.com/formancehq/ledger/pkg/ledger/lock" + "github.com/formancehq/ledger/pkg/ledger/monitor" "github.com/formancehq/ledger/pkg/ledger/numscript" "github.com/formancehq/ledger/pkg/ledger/query" "github.com/formancehq/ledger/pkg/ledger/runner" "github.com/formancehq/ledger/pkg/storage" + "github.com/formancehq/stack/libs/go-libs/logging" "github.com/pkg/errors" ) type Resolver struct { storageDriver storage.Driver + monitor monitor.Monitor lock sync.RWMutex locker lock.Locker - queryWorker *query.Worker //TODO(gfyrag): add a routine to clean old ledger ledgers map[string]*Ledger compiler *numscript.Compiler @@ -26,14 +28,14 @@ type Resolver struct { func NewResolver( storageDriver storage.Driver, + monitor monitor.Monitor, locker lock.Locker, - queryWorker *query.Worker, allowPastTimestamps bool, ) *Resolver { return &Resolver{ storageDriver: storageDriver, + monitor: monitor, locker: locker, - queryWorker: queryWorker, compiler: numscript.NewCompiler(), ledgers: map[string]*Ledger{}, allowPastTimestamps: allowPastTimestamps, @@ -65,7 +67,20 @@ func (r *Resolver) GetLedger(ctx context.Context, name string) (*Ledger, error) return nil, err } - ledger = New(store, cache, runner, r.locker, r.queryWorker) + queryWorker := query.NewWorker(query.WorkerConfig{ + ChanSize: 1024, + }, r.storageDriver, store, r.monitor) + + go func() { + if err := queryWorker.Run(logging.ContextWithLogger( + context.Background(), + logging.FromContext(ctx), + )); err != nil { + panic(err) + } + }() + + ledger = New(store, cache, runner, r.locker, queryWorker) r.ledgers[name] = ledger } diff --git a/pkg/ledgertesting/storage.go b/pkg/ledgertesting/storage.go index 86381fd0b..7842c8a5c 100644 --- a/pkg/ledgertesting/storage.go +++ b/pkg/ledgertesting/storage.go @@ -17,6 +17,10 @@ func StorageDriver(t pgtesting.TestingT) *sqlstorage.Driver { db, err := sqlstorage.OpenSQLDB(pgServer.ConnString()) require.NoError(t, err) + t.Cleanup(func() { + db.Close() + }) + return sqlstorage.NewDriver("postgres", schema.NewPostgresDB(db)) } diff --git a/pkg/storage/driver.go b/pkg/storage/driver.go index 64429681b..993bfe770 100644 --- a/pkg/storage/driver.go +++ b/pkg/storage/driver.go @@ -26,31 +26,41 @@ type LedgerStore interface { IsInitialized() bool Name() string + RunInTransaction(ctx context.Context, f func(ctx context.Context, store LedgerStore) error) error + + AppendLog(context.Context, *core.Log) error GetNextLogID(ctx context.Context) (uint64, error) ReadLogsStartingFromID(ctx context.Context, id uint64) ([]core.Log, error) UpdateNextLogID(ctx context.Context, id uint64) error + GetLogs(context.Context, *LogsQuery) (api.Cursor[core.Log], error) + 
GetLastLog(context.Context) (*core.Log, error) + ReadLogWithReference(ctx context.Context, reference string) (*core.Log, error) + ReadLastLogWithType(ctx context.Context, logType ...core.LogType) (*core.Log, error) + InsertTransactions(ctx context.Context, transaction ...core.ExpandedTransaction) error - UpdateAccountMetadata(ctx context.Context, id string, metadata core.Metadata) error UpdateTransactionMetadata(ctx context.Context, id uint64, metadata core.Metadata) error - GetAccountWithVolumes(ctx context.Context, addr string) (*core.AccountWithVolumes, error) - UpdateVolumes(ctx context.Context, volumes core.AccountsAssetsVolumes) error - EnsureAccountExists(ctx context.Context, account string) error + UpdateTransactionsMetadata(ctx context.Context, txs ...core.TransactionWithMetadata) error CountTransactions(context.Context, TransactionsQuery) (uint64, error) GetTransactions(context.Context, TransactionsQuery) (api.Cursor[core.ExpandedTransaction], error) GetTransaction(ctx context.Context, txid uint64) (*core.ExpandedTransaction, error) - GetAccount(ctx context.Context, accountAddress string) (*core.Account, error) - GetAssetsVolumes(ctx context.Context, accountAddress string) (core.AssetsVolumes, error) + + UpdateAccountMetadata(ctx context.Context, id string, metadata core.Metadata) error + UpdateAccountsMetadata(ctx context.Context, accounts []core.Account) error + EnsureAccountExists(ctx context.Context, account string) error + EnsureAccountsExist(ctx context.Context, accounts []string) error CountAccounts(context.Context, AccountsQuery) (uint64, error) GetAccounts(context.Context, AccountsQuery) (api.Cursor[core.Account], error) + GetAccountWithVolumes(ctx context.Context, addr string) (*core.AccountWithVolumes, error) + GetAccount(ctx context.Context, accountAddress string) (*core.Account, error) + + UpdateVolumes(ctx context.Context, volumes ...core.AccountsAssetsVolumes) error + GetAssetsVolumes(ctx context.Context, accountAddress string) (core.AssetsVolumes, error) + GetBalances(context.Context, BalancesQuery) (api.Cursor[core.AccountsBalances], error) GetBalancesAggregated(context.Context, BalancesQuery) (core.AssetsBalances, error) - GetLastLog(context.Context) (*core.Log, error) - GetLogs(context.Context, *LogsQuery) (api.Cursor[core.Log], error) - AppendLog(context.Context, *core.Log) error + GetMigrationsAvailable() ([]core.MigrationInfo, error) GetMigrationsDone(context.Context) ([]core.MigrationInfo, error) - ReadLogWithReference(ctx context.Context, reference string) (*core.Log, error) - ReadLastLogWithType(ctx context.Context, logType ...core.LogType) (*core.Log, error) } type Driver interface { diff --git a/pkg/storage/sqlstorage/ledger/accounts.go b/pkg/storage/sqlstorage/ledger/accounts.go index e14d039b4..96f95f652 100644 --- a/pkg/storage/sqlstorage/ledger/accounts.go +++ b/pkg/storage/sqlstorage/ledger/accounts.go @@ -312,6 +312,27 @@ func (s *Store) EnsureAccountExists(ctx context.Context, account string) error { return sqlerrors.PostgresError(err) } +func (s *Store) EnsureAccountsExist(ctx context.Context, accounts []string) error { + if !s.isInitialized { + return storage.ErrStoreNotInitialized + } + + accs := make([]*Accounts, len(accounts)) + for i, a := range accounts { + accs[i] = &Accounts{ + Address: a, + Metadata: make(map[string]interface{}), + } + } + + _, err := s.schema.NewInsert(accountsTableName). + Model(&accs). + Ignore(). 
+ Exec(ctx) + + return sqlerrors.PostgresError(err) +} + func (s *Store) UpdateAccountMetadata(ctx context.Context, address string, metadata core.Metadata) error { if !s.isInitialized { return storage.ErrStoreNotInitialized @@ -327,6 +348,26 @@ func (s *Store) UpdateAccountMetadata(ctx context.Context, address string, metad On("CONFLICT (address) DO UPDATE"). Set("metadata = accounts.metadata || EXCLUDED.metadata"). Exec(ctx) - return err + return sqlerrors.PostgresError(err) +} + +func (s *Store) UpdateAccountsMetadata(ctx context.Context, accounts []core.Account) error { + if !s.isInitialized { + return storage.ErrStoreNotInitialized + } + + accs := make([]*Accounts, len(accounts)) + for i, a := range accounts { + accs[i] = &Accounts{ + Address: a.Address, + Metadata: a.Metadata, + } + } + _, err := s.schema.NewInsert(accountsTableName). + Model(&accs). + On("CONFLICT (address) DO UPDATE"). + Set("metadata = accounts.metadata || EXCLUDED.metadata"). + Exec(ctx) + return sqlerrors.PostgresError(err) } diff --git a/pkg/storage/sqlstorage/ledger/accounts_test.go b/pkg/storage/sqlstorage/ledger/accounts_test.go index 7fa1b2ee9..c1530e296 100644 --- a/pkg/storage/sqlstorage/ledger/accounts_test.go +++ b/pkg/storage/sqlstorage/ledger/accounts_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/formancehq/ledger/pkg/core" "github.com/formancehq/ledger/pkg/ledgertesting" "github.com/formancehq/ledger/pkg/storage" "github.com/formancehq/ledger/pkg/storage/sqlstorage" @@ -79,4 +80,74 @@ func TestAccounts(t *testing.T) { _, err := store.GetAccounts(context.Background(), q) assert.NoError(t, err, "balance operator filter should not fail") }) + + t.Run("success account insertion", func(t *testing.T) { + addr := "test:account" + metadata := core.Metadata(map[string]any{ + "foo": "bar", + }) + + err := store.UpdateAccountMetadata(context.Background(), addr, metadata) + assert.NoError(t, err, "account insertion should not fail") + + account, err := store.GetAccount(context.Background(), addr) + assert.NoError(t, err, "account retrieval should not fail") + + assert.Equal(t, addr, account.Address, "account address should match") + assert.Equal(t, metadata, account.Metadata, "account metadata should match") + }) + + t.Run("success multiple account insertions", func(t *testing.T) { + accounts := []core.Account{ + { + Address: "test:account1", + Metadata: core.Metadata(map[string]any{"foo1": "bar1"}), + }, + { + Address: "test:account2", + Metadata: core.Metadata(map[string]any{"foo2": "bar2"}), + }, + { + Address: "test:account3", + Metadata: core.Metadata(map[string]any{"foo3": "bar3"}), + }, + } + + err := store.UpdateAccountsMetadata(context.Background(), accounts) + assert.NoError(t, err, "account insertion should not fail") + + for _, account := range accounts { + acc, err := store.GetAccount(context.Background(), account.Address) + assert.NoError(t, err, "account retrieval should not fail") + + assert.Equal(t, account.Address, acc.Address, "account address should match") + assert.Equal(t, account.Metadata, acc.Metadata, "account metadata should match") + } + }) + + t.Run("success ensure account exists", func(t *testing.T) { + addr := "test:account:4" + + err := store.EnsureAccountExists(context.Background(), addr) + assert.NoError(t, err, "account insertion should not fail") + + account, err := store.GetAccount(context.Background(), addr) + assert.NoError(t, err, "account retrieval should not fail") + + assert.Equal(t, addr, account.Address, "account 
address should match") + }) + + t.Run("success ensure mulitple accounts exist", func(t *testing.T) { + addrs := []string{"test:account:4", "test:account:5", "test:account:6"} + + err := store.EnsureAccountsExist(context.Background(), addrs) + assert.NoError(t, err, "account insertion should not fail") + + for _, addr := range addrs { + account, err := store.GetAccount(context.Background(), addr) + assert.NoError(t, err, "account retrieval should not fail") + + assert.Equal(t, addr, account.Address, "account address should match") + } + }) } diff --git a/pkg/storage/sqlstorage/ledger/logs.go b/pkg/storage/sqlstorage/ledger/logs.go index ca6267474..3513151f8 100644 --- a/pkg/storage/sqlstorage/ledger/logs.go +++ b/pkg/storage/sqlstorage/ledger/logs.go @@ -2,6 +2,7 @@ package ledger import ( "context" + "database/sql" "database/sql/driver" "encoding/base64" "encoding/json" @@ -60,7 +61,7 @@ func (s *Store) batchLogs(ctx context.Context, logs []*core.Log) error { return errors.Wrap(err, "reading last log") } - txn, err := s.schema.Begin() + txn, err := s.schema.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return err } @@ -176,13 +177,13 @@ func (s *Store) GetLogs(ctx context.Context, q *storage.LogsQuery) (api.Cursor[c } defer rows.Close() - rawLogsV2 := []LogsV2{} - err = s.schema.ScanRows(ctx, rows, &rawLogsV2) - if err != nil { - return api.Cursor[core.Log]{}, errors.Wrap(err, "scanning rows") - } + for rows.Next() { + var raw LogsV2 + err = rows.Scan(&raw.ID, &raw.Type, &raw.Hash, &raw.Date, &raw.Data, &raw.Reference) + if err != nil { + return api.Cursor[core.Log]{}, sqlerrors.PostgresError(err) + } - for _, raw := range rawLogsV2 { payload, err := core.HydrateLog(core.LogType(raw.Type), raw.Data) if err != nil { return api.Cursor[core.Log]{}, errors.Wrap(err, "hydrating log") diff --git a/pkg/storage/sqlstorage/ledger/store.go b/pkg/storage/sqlstorage/ledger/store.go index d4bf7f03b..e2f8d3706 100644 --- a/pkg/storage/sqlstorage/ledger/store.go +++ b/pkg/storage/sqlstorage/ledger/store.go @@ -2,6 +2,7 @@ package ledger import ( "context" + "database/sql" "github.com/formancehq/ledger/pkg/core" "github.com/formancehq/ledger/pkg/storage" @@ -66,6 +67,36 @@ func (s *Store) IsInitialized() bool { return s.isInitialized } +func (s *Store) RunInTransaction(ctx context.Context, f func(ctx context.Context, store storage.LedgerStore) error) error { + tx, err := s.schema.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + + // Create a fake store to use the tx instead of the bun.DB struct + newStore := NewStore( + ctx, + schema.NewSchema(tx.Tx, s.schema.Name()), + s.onClose, + s.onDelete, + ) + + newStore.isInitialized = s.isInitialized + + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + + err = f(ctx, newStore) + if err != nil { + return err + } + + return tx.Commit() +} + func NewStore( ctx context.Context, schema schema.Schema, diff --git a/pkg/storage/sqlstorage/ledger/transactions.go b/pkg/storage/sqlstorage/ledger/transactions.go index f65245d56..112a55c0c 100644 --- a/pkg/storage/sqlstorage/ledger/transactions.go +++ b/pkg/storage/sqlstorage/ledger/transactions.go @@ -408,3 +408,33 @@ func (s *Store) UpdateTransactionMetadata(ctx context.Context, id uint64, metada return sqlerrors.PostgresError(err) } + +func (s *Store) UpdateTransactionsMetadata(ctx context.Context, transactionsWithMetadata ...core.TransactionWithMetadata) error { + if !s.isInitialized { + return storage.ErrStoreNotInitialized + } + + txs := make([]*Transactions, 0, 
len(transactionsWithMetadata)) + for _, tx := range transactionsWithMetadata { + metadataData, err := json.Marshal(tx.Metadata) + if err != nil { + return err + } + + txs = append(txs, &Transactions{ + ID: tx.ID, + Metadata: metadataData, + }) + } + + values := s.schema.NewValues(&txs) + + _, err := s.schema.NewUpdate(TransactionsTableName). + With("_data", values). + Model((*Transactions)(nil)). + TableExpr("_data"). + Set("metadata = transactions.metadata || _data.metadata"). + Where(fmt.Sprintf("%s.id = _data.id", TransactionsTableName)). + Exec(ctx) + return err +} diff --git a/pkg/storage/sqlstorage/ledger/transactions_test.go b/pkg/storage/sqlstorage/ledger/transactions_test.go new file mode 100644 index 000000000..9e0a6a2de --- /dev/null +++ b/pkg/storage/sqlstorage/ledger/transactions_test.go @@ -0,0 +1,147 @@ +package ledger_test + +import ( + "context" + "math/big" + "testing" + "time" + + "github.com/formancehq/ledger/pkg/core" + "github.com/formancehq/ledger/pkg/ledgertesting" + "github.com/formancehq/ledger/pkg/storage" + "github.com/formancehq/ledger/pkg/storage/sqlstorage" + "github.com/stretchr/testify/assert" +) + +func TestTransactions(t *testing.T) { + d := ledgertesting.StorageDriver(t) + + assert.NoError(t, d.Initialize(context.Background())) + + defer func(d *sqlstorage.Driver, ctx context.Context) { + assert.NoError(t, d.Close(ctx)) + }(d, context.Background()) + + store, _, err := d.GetLedgerStore(context.Background(), "foo", true) + assert.NoError(t, err) + + _, err = store.Initialize(context.Background()) + assert.NoError(t, err) + + t.Run("success inserting transaction", func(t *testing.T) { + tx1 := core.ExpandedTransaction{ + Transaction: core.Transaction{ + ID: 0, + TransactionData: core.TransactionData{ + Postings: core.Postings{ + { + Source: "world", + Destination: "alice", + Amount: big.NewInt(100), + Asset: "USD", + }, + }, + Timestamp: now.Add(-3 * time.Hour), + Metadata: core.Metadata{}, + }, + }, + } + + err := store.InsertTransactions(context.Background(), tx1) + assert.NoError(t, err, "inserting transaction should not fail") + + tx, err := store.GetTransaction(context.Background(), 0) + assert.NoError(t, err, "getting transaction should not fail") + assert.Equal(t, &tx1, tx, "transaction should be equal") + }) + + t.Run("success inserting multiple transactions", func(t *testing.T) { + tx2 := core.ExpandedTransaction{ + Transaction: core.Transaction{ + ID: 1, + TransactionData: core.TransactionData{ + Postings: core.Postings{ + { + Source: "world", + Destination: "polo", + Amount: big.NewInt(200), + Asset: "USD", + }, + }, + Timestamp: now.Add(-2 * time.Hour), + Metadata: core.Metadata{}, + }, + }, + } + + tx3 := core.ExpandedTransaction{ + Transaction: core.Transaction{ + ID: 2, + TransactionData: core.TransactionData{ + Postings: core.Postings{ + { + Source: "world", + Destination: "gfyrag", + Amount: big.NewInt(150), + Asset: "USD", + }, + }, + Timestamp: now.Add(-1 * time.Hour), + Metadata: core.Metadata{}, + }, + }, + } + + err := store.InsertTransactions(context.Background(), tx2, tx3) + assert.NoError(t, err, "inserting multiple transactions should not fail") + + tx, err := store.GetTransaction(context.Background(), 1) + assert.NoError(t, err, "getting transaction should not fail") + assert.Equal(t, &tx2, tx, "transaction should be equal") + + tx, err = store.GetTransaction(context.Background(), 2) + assert.NoError(t, err, "getting transaction should not fail") + assert.Equal(t, &tx3, tx, "transaction 
should be equal") + }) + + t.Run("success counting transactions", func(t *testing.T) { + count, err := store.CountTransactions(context.Background(), storage.TransactionsQuery{}) + assert.NoError(t, err, "counting transactions should not fail") + assert.Equal(t, uint64(3), count, "count should be equal") + }) + + t.Run("success updating transaction metadata", func(t *testing.T) { + metadata := core.Metadata(map[string]any{ + "foo": "bar", + }) + err := store.UpdateTransactionMetadata(context.Background(), 0, metadata) + assert.NoError(t, err, "updating transaction metadata should not fail") + + tx, err := store.GetTransaction(context.Background(), 0) + assert.NoError(t, err, "getting transaction should not fail") + assert.Equal(t, tx.Metadata, metadata, "metadata should be equal") + }) + + t.Run("success updating multiple transaction metadata", func(t *testing.T) { + txToUpdate1 := core.TransactionWithMetadata{ + ID: 1, + Metadata: core.Metadata(map[string]any{"foo1": "bar2"}), + } + txToUpdate2 := core.TransactionWithMetadata{ + ID: 2, + Metadata: core.Metadata(map[string]any{"foo2": "bar2"}), + } + txs := []core.TransactionWithMetadata{txToUpdate1, txToUpdate2} + + err := store.UpdateTransactionsMetadata(context.Background(), txs...) + assert.NoError(t, err, "updating multiple transaction metadata should not fail") + + tx, err := store.GetTransaction(context.Background(), 1) + assert.NoError(t, err, "getting transaction should not fail") + assert.Equal(t, tx.Metadata, txToUpdate1.Metadata, "metadata should be equal") + + tx, err = store.GetTransaction(context.Background(), 2) + assert.NoError(t, err, "getting transaction should not fail") + assert.Equal(t, tx.Metadata, txToUpdate2.Metadata, "metadata should be equal") + }) +} diff --git a/pkg/storage/sqlstorage/ledger/volumes.go b/pkg/storage/sqlstorage/ledger/volumes.go index 604beaea7..0e081219e 100644 --- a/pkg/storage/sqlstorage/ledger/volumes.go +++ b/pkg/storage/sqlstorage/ledger/volumes.go @@ -23,31 +23,40 @@ type Volumes struct { Output uint64 `bun:"output,type:numeric"` } -func (s *Store) UpdateVolumes(ctx context.Context, volumes core.AccountsAssetsVolumes) error { +func (s *Store) UpdateVolumes(ctx context.Context, volumes ...core.AccountsAssetsVolumes) error { if !s.isInitialized { return storage.ErrStoreNotInitialized } - for account, accountVolumes := range volumes { - for asset, volumes := range accountVolumes { - v := &Volumes{ - Account: account, - Asset: asset, - Input: volumes.Input.Uint64(), - Output: volumes.Output.Uint64(), + volumesMap := make(map[string]*Volumes) + for _, vs := range volumes { + for account, accountVolumes := range vs { + for asset, volumes := range accountVolumes { + // De-duplicate same volumes to only have the last version + volumesMap[account+asset] = &Volumes{ + Account: account, + Asset: asset, + Input: volumes.Input.Uint64(), + Output: volumes.Output.Uint64(), + } } + } + } - query := s.schema.NewInsert(volumesTableName). - Model(v). - On("CONFLICT (account, asset) DO UPDATE"). - Set("input = EXCLUDED.input, output = EXCLUDED.output"). - String() + vls := make([]*Volumes, 0, len(volumes)) + for _, v := range volumesMap { + vls = append(vls, v) + } - _, err := s.schema.ExecContext(ctx, query) - if err != nil { - return sqlerrors.PostgresError(err) - } - } + query := s.schema.NewInsert(volumesTableName). + Model(&vls). + On("CONFLICT (account, asset) DO UPDATE"). + Set("input = EXCLUDED.input, output = EXCLUDED.output"). 
+ String() + + _, err := s.schema.ExecContext(ctx, query) + if err != nil { + return sqlerrors.PostgresError(err) } return nil diff --git a/pkg/storage/sqlstorage/ledger/volumes_test.go b/pkg/storage/sqlstorage/ledger/volumes_test.go new file mode 100644 index 000000000..c90c02ade --- /dev/null +++ b/pkg/storage/sqlstorage/ledger/volumes_test.go @@ -0,0 +1,92 @@ +package ledger_test + +import ( + "context" + "math/big" + "testing" + + "github.com/formancehq/ledger/pkg/core" + "github.com/formancehq/ledger/pkg/ledgertesting" + "github.com/formancehq/ledger/pkg/storage/sqlstorage" + "github.com/stretchr/testify/assert" +) + +func TestVolumes(t *testing.T) { + d := ledgertesting.StorageDriver(t) + + assert.NoError(t, d.Initialize(context.Background())) + + defer func(d *sqlstorage.Driver, ctx context.Context) { + assert.NoError(t, d.Close(ctx)) + }(d, context.Background()) + + store, _, err := d.GetLedgerStore(context.Background(), "foo", true) + assert.NoError(t, err) + + _, err = store.Initialize(context.Background()) + assert.NoError(t, err) + + t.Run("success update volumes", func(t *testing.T) { + foo := core.AssetsVolumes{ + "bar": { + Input: big.NewInt(1), + Output: big.NewInt(2), + }, + } + + foo2 := core.AssetsVolumes{ + "bar2": { + Input: big.NewInt(3), + Output: big.NewInt(4), + }, + } + + volumes := core.AccountsAssetsVolumes{ + "foo": foo, + "foo2": foo2, + } + + err := store.UpdateVolumes(context.Background(), volumes) + assert.NoError(t, err, "update volumes should not fail") + + assetVolumes, err := store.GetAssetsVolumes(context.Background(), "foo") + assert.NoError(t, err, "get asset volumes should not fail") + assert.Equal(t, foo, assetVolumes, "asset volumes should be equal") + + assetVolumes, err = store.GetAssetsVolumes(context.Background(), "foo2") + assert.NoError(t, err, "get asset volumes should not fail") + assert.Equal(t, foo2, assetVolumes, "asset volumes should be equal") + }) + + t.Run("success update same volume", func(t *testing.T) { + foo := core.AssetsVolumes{ + "bar": { + Input: big.NewInt(1), + Output: big.NewInt(2), + }, + } + + foo2 := core.AssetsVolumes{ + "bar": { + Input: big.NewInt(3), + Output: big.NewInt(4), + }, + } + + volumes := []core.AccountsAssetsVolumes{ + { + "foo": foo, + }, + { + "foo": foo2, + }, + } + + err := store.UpdateVolumes(context.Background(), volumes...) 
+ assert.NoError(t, err, "update volumes should not fail") + + assetVolumes, err := store.GetAssetsVolumes(context.Background(), "foo") + assert.NoError(t, err, "get asset volumes should not fail") + assert.Equal(t, foo2, assetVolumes, "asset volumes should be equal") + }) +} diff --git a/pkg/storage/sqlstorage/schema/schema.go b/pkg/storage/sqlstorage/schema/schema.go index 6b547ee26..782c75051 100644 --- a/pkg/storage/sqlstorage/schema/schema.go +++ b/pkg/storage/sqlstorage/schema/schema.go @@ -9,10 +9,17 @@ import ( ) type Schema struct { - *bun.DB + bun.IDB name string } +func NewSchema(db bun.IDB, name string) Schema { + return Schema{ + IDB: db, + name: name, + } +} + func (s *Schema) Name() string { return s.name } @@ -40,7 +47,7 @@ func (s *Schema) Delete(ctx context.Context) error { } func (s *Schema) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { - bunTx, err := s.DB.BeginTx(ctx, opts) + bunTx, err := s.IDB.BeginTx(ctx, opts) if err != nil { return nil, err } @@ -55,30 +62,30 @@ func (s *Schema) Flavor() string { } func (s *Schema) Close(ctx context.Context) error { - // Do not close the DB, it is shared with other schemas + // Do not close the IDB, it is shared with other schemas return nil } // Override all bun methods to use the schema name func (s *Schema) NewInsert(tableName string) *bun.InsertQuery { - return s.DB.NewInsert().ModelTableExpr("?0.?1", bun.Ident(s.Name()), bun.Ident(tableName)) + return s.IDB.NewInsert().ModelTableExpr("?0.?1", bun.Ident(s.Name()), bun.Ident(tableName)) } func (s *Schema) NewUpdate(tableName string) *bun.UpdateQuery { - return s.DB.NewUpdate().ModelTableExpr("?0.?1", bun.Ident(s.Name()), bun.Ident(tableName)) + return s.IDB.NewUpdate().ModelTableExpr("?0.?1", bun.Ident(s.Name()), bun.Ident(tableName)) } func (s *Schema) NewSelect(tableName string) *bun.SelectQuery { - return s.DB.NewSelect().ModelTableExpr("?0.?1 as ?1", bun.Ident(s.Name()), bun.Ident(tableName)) + return s.IDB.NewSelect().ModelTableExpr("?0.?1 as ?1", bun.Ident(s.Name()), bun.Ident(tableName)) } func (s *Schema) NewCreateTable(tableName string) *bun.CreateTableQuery { - return s.DB.NewCreateTable().ModelTableExpr("?0.?1", bun.Ident(s.Name()), bun.Ident(tableName)) + return s.IDB.NewCreateTable().ModelTableExpr("?0.?1", bun.Ident(s.Name()), bun.Ident(tableName)) } func (s *Schema) NewDelete(tableName string) *bun.DeleteQuery { - return s.DB.NewDelete().ModelTableExpr("?0.?1", bun.Ident(s.Name()), bun.Ident(tableName)) + return s.IDB.NewDelete().ModelTableExpr("?0.?1", bun.Ident(s.Name()), bun.Ident(tableName)) } type DB interface { @@ -105,7 +112,7 @@ func (p *postgresDB) Initialize(ctx context.Context) error { func (p *postgresDB) Schema(ctx context.Context, name string) (Schema, error) { return Schema{ - DB: p.db, + IDB: p.db, name: name, }, nil }