diff --git a/go.mod b/go.mod index 57507d0fe..b7dc25055 100644 --- a/go.mod +++ b/go.mod @@ -86,7 +86,7 @@ require ( github.com/buger/jsonparser v1.1.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/containerd/continuity v0.4.3 // indirect - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc github.com/dlclark/regexp2 v1.11.4 // indirect github.com/dnwe/otelsarama v0.0.0-20240308230250-9388d9d40bc0 // indirect github.com/docker/cli v27.3.1+incompatible // indirect diff --git a/internal/controller/ledger/store_generated_test.go b/internal/controller/ledger/store_generated_test.go index 1618ae7c3..59479b0da 100644 --- a/internal/controller/ledger/store_generated_test.go +++ b/internal/controller/ledger/store_generated_test.go @@ -364,7 +364,7 @@ func (mr *MockStoreMockRecorder) GetVolumesWithBalances(ctx, q any) *gomock.Call // IsUpToDate mocks base method. func (m *MockStore) IsUpToDate(ctx context.Context) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsUpToDate", ctx) + ret := m.ctrl.Call(m, "HasMinimalVersion", ctx) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 @@ -373,7 +373,7 @@ func (m *MockStore) IsUpToDate(ctx context.Context) (bool, error) { // IsUpToDate indicates an expected call of IsUpToDate. func (mr *MockStoreMockRecorder) IsUpToDate(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsUpToDate", reflect.TypeOf((*MockStore)(nil).IsUpToDate), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasMinimalVersion", reflect.TypeOf((*MockStore)(nil).IsUpToDate), ctx) } // ListAccounts mocks base method. diff --git a/internal/storage/bucket/bucket.go b/internal/storage/bucket/bucket.go index ef1347243..8db0c7dda 100644 --- a/internal/storage/bucket/bucket.go +++ b/internal/storage/bucket/bucket.go @@ -24,7 +24,7 @@ func (b *Bucket) Migrate(ctx context.Context, tracer trace.Tracer) error { return migrate(ctx, tracer, b.db, b.name) } -func (b *Bucket) IsUpToDate(ctx context.Context) (bool, error) { +func (b *Bucket) HasMinimalVersion(ctx context.Context) (bool, error) { migrator := GetMigrator(b.db, b.name) lastVersion, err := migrator.GetLastVersion(ctx) if err != nil { diff --git a/internal/storage/bucket/migrations/11-make-stateless/up.sql b/internal/storage/bucket/migrations/11-make-stateless/up.sql index 9b92f4282..b8d2c9666 100644 --- a/internal/storage/bucket/migrations/11-make-stateless/up.sql +++ b/internal/storage/bucket/migrations/11-make-stateless/up.sql @@ -513,3 +513,35 @@ $do$ end loop; END $do$; + +-- following index will enforce uniqueness of transaction reference until the appropriate index is full built (see next migration) +create or replace function enforce_reference_uniqueness() returns trigger + security definer + language plpgsql +as +$$ +begin + -- Temporary magic number + -- The migration 13 will remove the trigger + perform pg_advisory_xact_lock(9999999); + + if exists( + select 1 + from transactions + where reference = new.reference + and ledger = new.ledger + and id != new.id + ) then + raise exception 'duplicate reference'; + end if; + + return new; +end +$$ set search_path from current; + +create constraint trigger enforce_reference_uniqueness +after insert on transactions +deferrable initially deferred +for each row +when ( new.reference is not null ) +execute procedure enforce_reference_uniqueness(); \ No newline at end of file diff --git a/internal/storage/bucket/migrations/13-create-ledger-indexes/up.sql b/internal/storage/bucket/migrations/13-create-ledger-indexes/up.sql index 1efb24192..5fa90936e 100644 --- a/internal/storage/bucket/migrations/13-create-ledger-indexes/up.sql +++ b/internal/storage/bucket/migrations/13-create-ledger-indexes/up.sql @@ -1,5 +1,8 @@ set search_path = '{{.Schema}}'; +drop trigger enforce_reference_uniqueness on transactions; +drop function enforce_reference_uniqueness(); + drop index transactions_reference; alter index transactions_reference2 rename to transactions_reference; diff --git a/internal/storage/driver/driver.go b/internal/storage/driver/driver.go index 7dd772242..a260fdfdf 100644 --- a/internal/storage/driver/driver.go +++ b/internal/storage/driver/driver.go @@ -216,6 +216,10 @@ func (d *Driver) UpgradeAllBuckets(ctx context.Context) error { return nil } +func (d *Driver) GetDB() *bun.DB { + return d.db +} + func New(db *bun.DB, opts ...Option) *Driver { ret := &Driver{ db: db, diff --git a/internal/storage/ledger/legacy/adapters.go b/internal/storage/ledger/legacy/adapters.go index 3d94a53f9..4fa415d2d 100644 --- a/internal/storage/ledger/legacy/adapters.go +++ b/internal/storage/ledger/legacy/adapters.go @@ -116,7 +116,7 @@ func (d *DefaultStoreAdapter) GetVolumesWithBalances(ctx context.Context, q ledg } func (d *DefaultStoreAdapter) IsUpToDate(ctx context.Context) (bool, error) { - return d.newStore.IsUpToDate(ctx) + return d.newStore.HasMinimalVersion(ctx) } func (d *DefaultStoreAdapter) GetMigrationsInfo(ctx context.Context) ([]migrations.Info, error) { diff --git a/internal/storage/ledger/main_test.go b/internal/storage/ledger/main_test.go index 7670169d9..21b923a44 100644 --- a/internal/storage/ledger/main_test.go +++ b/internal/storage/ledger/main_test.go @@ -8,16 +8,13 @@ import ( . "github.com/formancehq/go-libs/v2/testing/utils" "github.com/formancehq/ledger/internal/storage/driver" ledgerstore "github.com/formancehq/ledger/internal/storage/ledger" - "go.opentelemetry.io/otel/trace/noop" "math/big" "os" - "sync/atomic" "testing" "github.com/formancehq/go-libs/v2/bun/bundebug" "github.com/formancehq/go-libs/v2/testing/docker" ledger "github.com/formancehq/ledger/internal" - "github.com/formancehq/ledger/internal/storage/bucket" "github.com/google/go-cmp/cmp" "github.com/uptrace/bun/dialect/pgdialect" @@ -33,7 +30,6 @@ import ( var ( srv = NewDeferred[*pgtesting.PostgresServer]() bunDB = NewDeferred[*bun.DB]() - ledgerCount = atomic.Int64{} ) func TestMain(m *testing.M) { @@ -68,10 +64,9 @@ type T interface { Cleanup(func()) } -func newLedgerStore(t T) *ledgerstore.Store { +func newDriver(t T) *driver.Driver { t.Helper() - ledgerName := uuid.NewString()[:8] ctx := logging.TestingContext() Wait(srv, bunDB) @@ -88,15 +83,23 @@ func newLedgerStore(t T) *ledgerstore.Store { require.NoError(t, driver.Migrate(ctx, db)) + return driver.New(bunDB.GetValue()) +} + +func newLedgerStore(t T) *ledgerstore.Store { + t.Helper() + + driver := newDriver(t) + ledgerName := uuid.NewString()[:8] + ctx := logging.TestingContext() + l := ledger.MustNewWithDefault(ledgerName) l.Bucket = ledgerName - l.ID = int(ledgerCount.Add(1)) - b := bucket.New(bunDB.GetValue(), ledgerName) - require.NoError(t, b.Migrate(ctx, noop.Tracer{})) - require.NoError(t, b.AddLedger(ctx, l, bunDB.GetValue())) + store, err := driver.CreateLedger(ctx, &l) + require.NoError(t, err) - return ledgerstore.New(bunDB.GetValue(), b, l) + return store } func bigIntComparer(v1 *big.Int, v2 *big.Int) bool { diff --git a/internal/storage/ledger/store.go b/internal/storage/ledger/store.go index d7611c114..26d2f6e05 100644 --- a/internal/storage/ledger/store.go +++ b/internal/storage/ledger/store.go @@ -186,8 +186,8 @@ func New(db bun.IDB, bucket *bucket.Bucket, ledger ledger.Ledger, opts ...Option return ret } -func (s *Store) IsUpToDate(ctx context.Context) (bool, error) { - return s.bucket.IsUpToDate(ctx) +func (s *Store) HasMinimalVersion(ctx context.Context) (bool, error) { + return s.bucket.HasMinimalVersion(ctx) } func (s *Store) GetMigrationsInfo(ctx context.Context) ([]migrations.Info, error) { diff --git a/internal/storage/ledger/transactions.go b/internal/storage/ledger/transactions.go index 801b1ef18..40c93c5fb 100644 --- a/internal/storage/ledger/transactions.go +++ b/internal/storage/ledger/transactions.go @@ -240,6 +240,25 @@ func (s *Store) selectTransactions(date *time.Time, expandVolumes, expandEffecti } func (s *Store) CommitTransaction(ctx context.Context, tx *ledger.Transaction) error { + + // todo(next-minor): remove that on ledger 2.3 when the corresponding index will be completely built (see migration 12) + //if tx.Reference != "" { + // // Magic number, as long as no other process try to take the same exact lock for another reason, it will be ok. + // // This code will be removed in the next minor by the way. + // _, err := s.db.ExecContext(ctx, `select pg_advisory_xact_lock(99999999999)`) + // if err != nil { + // return err + // } + // + // exists, err := s.db.NewSelect(). + // ModelTableExpr(s.GetPrefixedRelationName("transactions")). + // Where("reference = ?", tx.Reference). + // Exists(ctx) + // if exists { + // return ledgercontroller.NewErrTransactionReferenceConflict(tx.Reference) + // } + //} + postCommitVolumes, err := s.UpdateVolumes(ctx, tx.VolumeUpdates()...) if err != nil { return fmt.Errorf("failed to update balances: %w", err) @@ -400,6 +419,11 @@ func (s *Store) InsertTransaction(ctx context.Context, tx *ledger.Transaction) e if err.(postgres.ErrConstraintsFailed).GetConstraint() == "transactions_reference" { return nil, ledgercontroller.NewErrTransactionReferenceConflict(tx.Reference) } + case errors.Is(err, postgres.ErrRaisedException{}): + // todo(next-minor): remove this test + if err.(postgres.ErrRaisedException).GetMessage() == "duplicate reference" { + return nil, ledgercontroller.NewErrTransactionReferenceConflict(tx.Reference) + } default: return nil, err } diff --git a/internal/storage/ledger/transactions_test.go b/internal/storage/ledger/transactions_test.go index 129debaf0..5978f51a0 100644 --- a/internal/storage/ledger/transactions_test.go +++ b/internal/storage/ledger/transactions_test.go @@ -7,6 +7,9 @@ import ( "database/sql" "fmt" "github.com/alitto/pond" + "github.com/formancehq/ledger/internal/storage/bucket" + ledgerstore "github.com/formancehq/ledger/internal/storage/ledger" + "github.com/google/uuid" "math/big" "slices" "testing" @@ -603,6 +606,84 @@ func TestTransactionsInsert(t *testing.T) { require.Error(t, err) require.True(t, errors.Is(err, ledgercontroller.ErrTransactionReferenceConflict{})) }) + // todo(next-minor): remove this test + t.Run("check reference conflict with minimal store version", func(t *testing.T) { + t.Parallel() + + driver := newDriver(t) + ledgerName := uuid.NewString()[:8] + + l := ledger.MustNewWithDefault(ledgerName) + l.Bucket = ledgerName + + migrator := bucket.GetMigrator(driver.GetDB(), ledgerName) + for i := 0; i < bucket.MinimalSchemaVersion; i++ { + require.NoError(t, migrator.UpByOne(ctx)) + } + + b := bucket.New(driver.GetDB(), ledgerName) + err := b.AddLedger(ctx, l, driver.GetDB()) + require.NoError(t, err) + + store := ledgerstore.New(driver.GetDB(), b, l) + + const nbTry = 100 + + for i := 0; i < nbTry; i++ { + errChan := make(chan error, 2) + + // Create a simple tx + tx1 := ledger.Transaction{ + TransactionData: ledger.TransactionData{ + Timestamp: now, + Reference: fmt.Sprintf("foo:%d", i), + Postings: []ledger.Posting{ + ledger.NewPosting("world", "bank", "USD/2", big.NewInt(100)), + }, + }, + } + go func() { + errChan <- store.InsertTransaction(ctx, &tx1) + }() + + // Create another tx with the same reference + tx2 := ledger.Transaction{ + TransactionData: ledger.TransactionData{ + Timestamp: now, + Reference: fmt.Sprintf("foo:%d", i), + Postings: []ledger.Posting{ + ledger.NewPosting("world", "bank", "USD/2", big.NewInt(100)), + }, + }, + } + go func() { + errChan <- store.InsertTransaction(ctx, &tx2) + }() + + select { + case err1 := <-errChan: + if err1 != nil { + require.True(t, errors.Is(err1, ledgercontroller.ErrTransactionReferenceConflict{})) + select { + case err2 := <-errChan: + require.NoError(t, err2) + case <-time.After(time.Second): + require.Fail(t, "should have received an error") + } + } else { + select { + case err2 := <-errChan: + require.Error(t, err2) + require.True(t, errors.Is(err2, ledgercontroller.ErrTransactionReferenceConflict{})) + case <-time.After(time.Second): + require.Fail(t, "should have received an error") + } + } + case <-time.After(time.Second): + require.Fail(t, "should have received an error") + } + } + }) t.Run("check denormalization", func(t *testing.T) { t.Parallel()