diff --git a/internal/controller/ledger/controller_default.go b/internal/controller/ledger/controller_default.go index 084d6c1d4..6d833ba43 100644 --- a/internal/controller/ledger/controller_default.go +++ b/internal/controller/ledger/controller_default.go @@ -54,7 +54,7 @@ func (ctrl *DefaultController) GetMigrationsInfo(ctx context.Context) ([]migrati func (ctrl *DefaultController) runTx(ctx context.Context, parameters Parameters, fn func(sqlTX TX) (*ledger.Log, error)) (*ledger.Log, error) { var log *ledger.Log - err := ctrl.store.WithTX(ctx, func(tx TX) (commit bool, err error) { + err := ctrl.store.WithTX(ctx, nil, func(tx TX) (commit bool, err error) { log, err = fn(tx) if err != nil { return false, err @@ -188,7 +188,7 @@ func (ctrl *DefaultController) Import(ctx context.Context, stream chan ledger.Lo return newErrInvalidState(ledger.StateInitializing, ctrl.ledger.State) } - return ctrl.store.WithTX(ctx, func(sqlTx TX) (bool, error) { + return ctrl.store.WithTX(ctx, nil, func(sqlTx TX) (bool, error) { for log := range stream { switch payload := log.Data.(type) { case ledger.NewTransactionLogPayload: diff --git a/internal/controller/ledger/controller_default_test.go b/internal/controller/ledger/controller_default_test.go index ab1a39d90..18840bbec 100644 --- a/internal/controller/ledger/controller_default_test.go +++ b/internal/controller/ledger/controller_default_test.go @@ -2,6 +2,7 @@ package ledger import ( "context" + "database/sql" "math/big" "testing" @@ -33,8 +34,8 @@ func TestCreateTransaction(t *testing.T) { Return(machine, nil) store.EXPECT(). - WithTX(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, fn func(tx TX) (bool, error)) error { + WithTX(gomock.Any(), nil, gomock.Any()). + DoAndReturn(func(ctx context.Context, _ *sql.TxOptions, fn func(tx TX) (bool, error)) error { _, err := fn(sqlTX) return err }) @@ -77,8 +78,8 @@ func TestRevertTransaction(t *testing.T) { l := NewDefaultController(ledger.Ledger{}, store, listener, machineFactory) store.EXPECT(). - WithTX(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, fn func(tx TX) (bool, error)) error { + WithTX(gomock.Any(), nil, gomock.Any()). + DoAndReturn(func(ctx context.Context, _ *sql.TxOptions, fn func(tx TX) (bool, error)) error { _, err := fn(sqlTX) return err }) @@ -122,8 +123,8 @@ func TestSaveTransactionMetadata(t *testing.T) { l := NewDefaultController(ledger.Ledger{}, store, listener, machineFactory) store.EXPECT(). - WithTX(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, fn func(tx TX) (bool, error)) error { + WithTX(gomock.Any(), nil, gomock.Any()). + DoAndReturn(func(ctx context.Context, _ *sql.TxOptions, fn func(tx TX) (bool, error)) error { _, err := fn(sqlTX) return err }) @@ -160,8 +161,8 @@ func TestDeleteTransactionMetadata(t *testing.T) { l := NewDefaultController(ledger.Ledger{}, store, listener, machineFactory) store.EXPECT(). - WithTX(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, fn func(tx TX) (bool, error)) error { + WithTX(gomock.Any(), nil, gomock.Any()). + DoAndReturn(func(ctx context.Context, _ *sql.TxOptions, fn func(tx TX) (bool, error)) error { _, err := fn(sqlTX) return err }) diff --git a/internal/controller/ledger/store.go b/internal/controller/ledger/store.go index 2f9fd7f78..cfd2198b0 100644 --- a/internal/controller/ledger/store.go +++ b/internal/controller/ledger/store.go @@ -2,6 +2,7 @@ package ledger import ( "context" + "database/sql" "encoding/json" "math/big" @@ -47,7 +48,7 @@ type TX interface { } type Store interface { - WithTX(context.Context, func(TX) (bool, error)) error + WithTX(context.Context, *sql.TxOptions, func(TX) (bool, error)) error GetDB() bun.IDB ListLogs(ctx context.Context, q GetLogsQuery) (*bunpaginate.Cursor[ledger.Log], error) ReadLogWithIdempotencyKey(ctx context.Context, ik string) (*ledger.Log, error) diff --git a/internal/controller/ledger/store_generated.go b/internal/controller/ledger/store_generated.go index dc09e6c49..5fd36bd44 100644 --- a/internal/controller/ledger/store_generated.go +++ b/internal/controller/ledger/store_generated.go @@ -11,6 +11,7 @@ package ledger import ( context "context" + sql "database/sql" reflect "reflect" bunpaginate "github.com/formancehq/go-libs/bun/bunpaginate" @@ -424,15 +425,15 @@ func (mr *MockStoreMockRecorder) ReadLogWithIdempotencyKey(ctx, ik any) *gomock. } // WithTX mocks base method. -func (m *MockStore) WithTX(arg0 context.Context, arg1 func(TX) (bool, error)) error { +func (m *MockStore) WithTX(arg0 context.Context, arg1 *sql.TxOptions, arg2 func(TX) (bool, error)) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WithTX", arg0, arg1) + ret := m.ctrl.Call(m, "WithTX", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // WithTX indicates an expected call of WithTX. -func (mr *MockStoreMockRecorder) WithTX(arg0, arg1 any) *gomock.Call { +func (mr *MockStoreMockRecorder) WithTX(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTX", reflect.TypeOf((*MockStore)(nil).WithTX), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTX", reflect.TypeOf((*MockStore)(nil).WithTX), arg0, arg1, arg2) } diff --git a/internal/storage/ledger/adapters.go b/internal/storage/ledger/adapters.go index ac60f0bee..504396861 100644 --- a/internal/storage/ledger/adapters.go +++ b/internal/storage/ledger/adapters.go @@ -23,8 +23,12 @@ type DefaultStoreAdapter struct { *Store } -func (d *DefaultStoreAdapter) WithTX(ctx context.Context, f func(ledgercontroller.TX) (bool, error)) error { - tx, err := d.GetDB().BeginTx(ctx, &sql.TxOptions{}) +func (d *DefaultStoreAdapter) WithTX(ctx context.Context, opts *sql.TxOptions, f func(ledgercontroller.TX) (bool, error)) error { + if opts == nil { + opts = &sql.TxOptions{} + } + + tx, err := d.GetDB().BeginTx(ctx, opts) if err != nil { return err }