diff --git a/internal/storage/ledger/balances.go b/internal/storage/ledger/balances.go index 91a6de3c4..566e45d76 100644 --- a/internal/storage/ledger/balances.go +++ b/internal/storage/ledger/balances.go @@ -187,7 +187,6 @@ func (s *Store) GetAggregatedBalances(ctx context.Context, q ledgercontroller.Ge return aggregatedVolumes.Aggregated.Balances(), nil } -// todo: need to handle previous version schema by looking moves func (s *Store) GetBalances(ctx context.Context, query ledgercontroller.BalanceQuery) (ledgercontroller.Balances, error) { return tracing.TraceWithMetric( ctx, @@ -224,25 +223,62 @@ func (s *Store) GetBalances(ctx context.Context, query ledgercontroller.BalanceQ } } - err := s.db.NewSelect(). - With( - "ins", - // Try to insert volumes with 0 values. - // This way, if the account has a 0 balance at this point, it will be locked as any other accounts. - // It the complete sql transaction fail, the account volumes will not be inserted. - s.db.NewInsert(). - Model(&accountsVolumes). - ModelTableExpr(s.GetPrefixedRelationName("accounts_volumes")). - On("conflict do nothing"), - ). - Model(&accountsVolumes). + // Try to insert volumes using last move (to keep compat with previous version) or 0 values. + // This way, if the account has a 0 balance at this point, it will be locked as any other accounts. + // If the complete sql transaction fails, the account volumes will not be inserted. + selectMoves := s.db.NewSelect(). + ModelTableExpr(s.GetPrefixedRelationName("moves")). + DistinctOn("accounts_address, asset"). + Column("accounts_address", "asset"). + ColumnExpr("first_value(post_commit_volumes) over (partition by accounts_address, asset order by seq desc) as post_commit_volumes"). + ColumnExpr("first_value(ledger) over (partition by accounts_address, asset order by seq desc) as ledger"). + Where("("+strings.Join(conditions, ") OR (")+")", args...) + + zeroValuesAndMoves := s.db.NewSelect(). + TableExpr("(?) data", selectMoves). + Column("ledger", "accounts_address", "asset"). + ColumnExpr("(post_commit_volumes).inputs as input"). + ColumnExpr("(post_commit_volumes).outputs as output"). + UnionAll( + s.db.NewSelect(). + TableExpr( + "(?) data", + s.db.NewSelect().NewValues(&accountsVolumes), + ). + Column("*"), + ) + + zeroValueOrMoves := s.db.NewSelect(). + TableExpr("(?) data", zeroValuesAndMoves). + Column("ledger", "accounts_address", "asset", "input", "output"). + DistinctOn("ledger, accounts_address, asset") + + insertDefaultValue := s.db.NewInsert(). + TableExpr(s.GetPrefixedRelationName("accounts_volumes")). + TableExpr("(" + zeroValueOrMoves.String() + ") data"). + On("conflict (ledger, accounts_address, asset) do nothing"). + Returning("ledger, accounts_address, asset, input, output") + + selectExistingValues := s.db.NewSelect(). ModelTableExpr(s.GetPrefixedRelationName("accounts_volumes")). - Column("accounts_address", "asset", "input", "output"). + Column("ledger", "accounts_address", "asset", "input", "output"). Where("("+strings.Join(conditions, ") OR (")+")", args...). For("update"). // notes(gfyrag): Keep order, it ensures consistent locking order and limit deadlocks - Order("accounts_address", "asset"). - Scan(ctx) + Order("accounts_address", "asset") + + finalQuery := s.db.NewSelect(). + With("inserted", insertDefaultValue). + With("existing", selectExistingValues). + ModelTableExpr( + "(?) accounts_volumes", + s.db.NewSelect(). + ModelTableExpr("inserted"). + UnionAll(s.db.NewSelect().ModelTableExpr("existing")), + ). + Model(&accountsVolumes) + + err := finalQuery.Scan(ctx) if err != nil { return nil, postgres.ResolveError(err) } diff --git a/internal/storage/ledger/balances_test.go b/internal/storage/ledger/balances_test.go index a015814c1..3096e8952 100644 --- a/internal/storage/ledger/balances_test.go +++ b/internal/storage/ledger/balances_test.go @@ -4,6 +4,7 @@ package ledger_test import ( "database/sql" + "github.com/formancehq/go-libs/v2/bun/bunpaginate" "math/big" "testing" @@ -127,6 +128,56 @@ func TestBalancesGet(t *testing.T) { require.NoError(t, err) require.Equal(t, 2, count) }) + + t.Run("with balance from move", func(t *testing.T) { + t.Parallel() + + tx := ledger.NewTransaction().WithPostings( + ledger.NewPosting("world", "bank", "USD", big.NewInt(100)), + ) + err := store.InsertTransaction(ctx, &tx) + require.NoError(t, err) + + bankAccount := ledger.Account{ + Address: "bank", + FirstUsage: tx.InsertedAt, + InsertionDate: tx.InsertedAt, + UpdatedAt: tx.InsertedAt, + } + _, err = store.UpsertAccount(ctx, &bankAccount) + require.NoError(t, err) + + err = store.InsertMoves(ctx, &ledger.Move{ + TransactionID: tx.ID, + IsSource: false, + Account: "bank", + Amount: (*bunpaginate.BigInt)(big.NewInt(100)), + Asset: "USD", + InsertionDate: tx.InsertedAt, + EffectiveDate: tx.InsertedAt, + PostCommitVolumes: pointer.For(ledger.NewVolumesInt64(100, 0)), + }) + require.NoError(t, err) + + balances, err := store.GetBalances(ctx, ledgercontroller.BalanceQuery{ + "bank": {"USD"}, + }) + require.NoError(t, err) + + require.NotNil(t, balances["bank"]) + RequireEqual(t, big.NewInt(100), balances["bank"]["USD"]) + + // Check a new line has been inserted into accounts_volumes table + volumes := &ledger.AccountsVolumes{} + err = store.GetDB().NewSelect(). + ModelTableExpr(store.GetPrefixedRelationName("accounts_volumes")). + Where("accounts_address = ?", "bank"). + Scan(ctx, volumes) + require.NoError(t, err) + + RequireEqual(t, big.NewInt(100), volumes.Input) + RequireEqual(t, big.NewInt(0), volumes.Output) + }) } func TestBalancesAggregates(t *testing.T) {