Skip to content

Commit

Permalink
feat: enforce uniqueness of reference when not fully migrated
Browse files Browse the repository at this point in the history
  • Loading branch information
gfyrag committed Nov 4, 2024
1 parent 6bddcd1 commit 0b6ab5f
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 18 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ require (
github.com/buger/jsonparser v1.1.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/containerd/continuity v0.4.3 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/dnwe/otelsarama v0.0.0-20240308230250-9388d9d40bc0 // indirect
github.com/docker/cli v27.3.1+incompatible // indirect
Expand Down
4 changes: 2 additions & 2 deletions internal/controller/ledger/store_generated_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion internal/storage/bucket/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (b *Bucket) Migrate(ctx context.Context, tracer trace.Tracer) error {
return migrate(ctx, tracer, b.db, b.name)
}

func (b *Bucket) IsUpToDate(ctx context.Context) (bool, error) {
func (b *Bucket) HasMinimalVersion(ctx context.Context) (bool, error) {
migrator := GetMigrator(b.db, b.name)
lastVersion, err := migrator.GetLastVersion(ctx)
if err != nil {
Expand Down
32 changes: 32 additions & 0 deletions internal/storage/bucket/migrations/11-make-stateless/up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,35 @@ $do$
end loop;
END
$do$;

-- following index will enforce uniqueness of transaction reference until the appropriate index is full built (see next migration)
create or replace function enforce_reference_uniqueness() returns trigger
security definer
language plpgsql
as
$$
begin
-- Temporary magic number
-- The migration 13 will remove the trigger
perform pg_advisory_xact_lock(9999999);

if exists(
select 1
from transactions
where reference = new.reference
and ledger = new.ledger
and id != new.id
) then
raise exception 'duplicate reference';
end if;

return new;
end
$$ set search_path from current;

create constraint trigger enforce_reference_uniqueness
after insert on transactions
deferrable initially deferred
for each row
when ( new.reference is not null )
execute procedure enforce_reference_uniqueness();
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
set search_path = '{{.Schema}}';

drop trigger enforce_reference_uniqueness on transactions;
drop function enforce_reference_uniqueness();

drop index transactions_reference;
alter index transactions_reference2 rename to transactions_reference;

Expand Down
4 changes: 4 additions & 0 deletions internal/storage/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ func (d *Driver) UpgradeAllBuckets(ctx context.Context) error {
return nil
}

func (d *Driver) GetDB() *bun.DB {
return d.db
}

func New(db *bun.DB, opts ...Option) *Driver {
ret := &Driver{
db: db,
Expand Down
2 changes: 1 addition & 1 deletion internal/storage/ledger/legacy/adapters.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (d *DefaultStoreAdapter) GetVolumesWithBalances(ctx context.Context, q ledg
}

func (d *DefaultStoreAdapter) IsUpToDate(ctx context.Context) (bool, error) {
return d.newStore.IsUpToDate(ctx)
return d.newStore.HasMinimalVersion(ctx)
}

func (d *DefaultStoreAdapter) GetMigrationsInfo(ctx context.Context) ([]migrations.Info, error) {
Expand Down
25 changes: 14 additions & 11 deletions internal/storage/ledger/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@ import (
. "github.com/formancehq/go-libs/v2/testing/utils"
"github.com/formancehq/ledger/internal/storage/driver"
ledgerstore "github.com/formancehq/ledger/internal/storage/ledger"
"go.opentelemetry.io/otel/trace/noop"
"math/big"
"os"
"sync/atomic"
"testing"

"github.com/formancehq/go-libs/v2/bun/bundebug"
"github.com/formancehq/go-libs/v2/testing/docker"
ledger "github.com/formancehq/ledger/internal"
"github.com/formancehq/ledger/internal/storage/bucket"
"github.com/google/go-cmp/cmp"

"github.com/uptrace/bun/dialect/pgdialect"
Expand All @@ -33,7 +30,6 @@ import (
var (
srv = NewDeferred[*pgtesting.PostgresServer]()
bunDB = NewDeferred[*bun.DB]()
ledgerCount = atomic.Int64{}
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -68,10 +64,9 @@ type T interface {
Cleanup(func())
}

func newLedgerStore(t T) *ledgerstore.Store {
func newDriver(t T) *driver.Driver {
t.Helper()

ledgerName := uuid.NewString()[:8]
ctx := logging.TestingContext()

Wait(srv, bunDB)
Expand All @@ -88,15 +83,23 @@ func newLedgerStore(t T) *ledgerstore.Store {

require.NoError(t, driver.Migrate(ctx, db))

return driver.New(bunDB.GetValue())
}

func newLedgerStore(t T) *ledgerstore.Store {
t.Helper()

driver := newDriver(t)
ledgerName := uuid.NewString()[:8]
ctx := logging.TestingContext()

l := ledger.MustNewWithDefault(ledgerName)
l.Bucket = ledgerName
l.ID = int(ledgerCount.Add(1))

b := bucket.New(bunDB.GetValue(), ledgerName)
require.NoError(t, b.Migrate(ctx, noop.Tracer{}))
require.NoError(t, b.AddLedger(ctx, l, bunDB.GetValue()))
store, err := driver.CreateLedger(ctx, &l)
require.NoError(t, err)

return ledgerstore.New(bunDB.GetValue(), b, l)
return store
}

func bigIntComparer(v1 *big.Int, v2 *big.Int) bool {
Expand Down
4 changes: 2 additions & 2 deletions internal/storage/ledger/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ func New(db bun.IDB, bucket *bucket.Bucket, ledger ledger.Ledger, opts ...Option
return ret
}

func (s *Store) IsUpToDate(ctx context.Context) (bool, error) {
return s.bucket.IsUpToDate(ctx)
func (s *Store) HasMinimalVersion(ctx context.Context) (bool, error) {
return s.bucket.HasMinimalVersion(ctx)
}

func (s *Store) GetMigrationsInfo(ctx context.Context) ([]migrations.Info, error) {
Expand Down
24 changes: 24 additions & 0 deletions internal/storage/ledger/transactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,25 @@ func (s *Store) selectTransactions(date *time.Time, expandVolumes, expandEffecti
}

func (s *Store) CommitTransaction(ctx context.Context, tx *ledger.Transaction) error {

// todo(next-minor): remove that on ledger 2.3 when the corresponding index will be completely built (see migration 12)
//if tx.Reference != "" {
// // Magic number, as long as no other process try to take the same exact lock for another reason, it will be ok.
// // This code will be removed in the next minor by the way.
// _, err := s.db.ExecContext(ctx, `select pg_advisory_xact_lock(99999999999)`)
// if err != nil {
// return err
// }
//
// exists, err := s.db.NewSelect().
// ModelTableExpr(s.GetPrefixedRelationName("transactions")).
// Where("reference = ?", tx.Reference).
// Exists(ctx)
// if exists {
// return ledgercontroller.NewErrTransactionReferenceConflict(tx.Reference)
// }
//}

postCommitVolumes, err := s.UpdateVolumes(ctx, tx.VolumeUpdates()...)
if err != nil {
return fmt.Errorf("failed to update balances: %w", err)
Expand Down Expand Up @@ -400,6 +419,11 @@ func (s *Store) InsertTransaction(ctx context.Context, tx *ledger.Transaction) e
if err.(postgres.ErrConstraintsFailed).GetConstraint() == "transactions_reference" {
return nil, ledgercontroller.NewErrTransactionReferenceConflict(tx.Reference)
}
case errors.Is(err, postgres.ErrRaisedException{}):
// todo(next-minor): remove this test
if err.(postgres.ErrRaisedException).GetMessage() == "duplicate reference" {
return nil, ledgercontroller.NewErrTransactionReferenceConflict(tx.Reference)
}
default:
return nil, err
}
Expand Down
81 changes: 81 additions & 0 deletions internal/storage/ledger/transactions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"database/sql"
"fmt"
"github.com/alitto/pond"
"github.com/formancehq/ledger/internal/storage/bucket"
ledgerstore "github.com/formancehq/ledger/internal/storage/ledger"
"github.com/google/uuid"
"math/big"
"slices"
"testing"
Expand Down Expand Up @@ -603,6 +606,84 @@ func TestTransactionsInsert(t *testing.T) {
require.Error(t, err)
require.True(t, errors.Is(err, ledgercontroller.ErrTransactionReferenceConflict{}))
})
// todo(next-minor): remove this test
t.Run("check reference conflict with minimal store version", func(t *testing.T) {
t.Parallel()

driver := newDriver(t)
ledgerName := uuid.NewString()[:8]

l := ledger.MustNewWithDefault(ledgerName)
l.Bucket = ledgerName

migrator := bucket.GetMigrator(driver.GetDB(), ledgerName)
for i := 0; i < bucket.MinimalSchemaVersion; i++ {
require.NoError(t, migrator.UpByOne(ctx))
}

b := bucket.New(driver.GetDB(), ledgerName)
err := b.AddLedger(ctx, l, driver.GetDB())
require.NoError(t, err)

store := ledgerstore.New(driver.GetDB(), b, l)

const nbTry = 100

for i := 0; i < nbTry; i++ {
errChan := make(chan error, 2)

// Create a simple tx
tx1 := ledger.Transaction{
TransactionData: ledger.TransactionData{
Timestamp: now,
Reference: fmt.Sprintf("foo:%d", i),
Postings: []ledger.Posting{
ledger.NewPosting("world", "bank", "USD/2", big.NewInt(100)),
},
},
}
go func() {
errChan <- store.InsertTransaction(ctx, &tx1)
}()

// Create another tx with the same reference
tx2 := ledger.Transaction{
TransactionData: ledger.TransactionData{
Timestamp: now,
Reference: fmt.Sprintf("foo:%d", i),
Postings: []ledger.Posting{
ledger.NewPosting("world", "bank", "USD/2", big.NewInt(100)),
},
},
}
go func() {
errChan <- store.InsertTransaction(ctx, &tx2)
}()

select {
case err1 := <-errChan:
if err1 != nil {
require.True(t, errors.Is(err1, ledgercontroller.ErrTransactionReferenceConflict{}))
select {
case err2 := <-errChan:
require.NoError(t, err2)
case <-time.After(time.Second):
require.Fail(t, "should have received an error")
}
} else {
select {
case err2 := <-errChan:
require.Error(t, err2)
require.True(t, errors.Is(err2, ledgercontroller.ErrTransactionReferenceConflict{}))
case <-time.After(time.Second):
require.Fail(t, "should have received an error")
}
}
case <-time.After(time.Second):
require.Fail(t, "should have received an error")
}
}
})
t.Run("check denormalization", func(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 0b6ab5f

Please sign in to comment.