Skip to content

Commit

Permalink
feat: make get balances call compat with v2 model
Browse files Browse the repository at this point in the history
  • Loading branch information
gfyrag committed Oct 19, 2024
1 parent 1b94951 commit f4228fd
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 16 deletions.
68 changes: 52 additions & 16 deletions internal/storage/ledger/balances.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ func (s *Store) GetAggregatedBalances(ctx context.Context, q ledgercontroller.Ge
return aggregatedVolumes.Aggregated.Balances(), nil
}

// todo: need to handle previous version schema by looking moves
func (s *Store) GetBalances(ctx context.Context, query ledgercontroller.BalanceQuery) (ledgercontroller.Balances, error) {
return tracing.TraceWithMetric(
ctx,
Expand Down Expand Up @@ -224,25 +223,62 @@ func (s *Store) GetBalances(ctx context.Context, query ledgercontroller.BalanceQ
}
}

err := s.db.NewSelect().
With(
"ins",
// Try to insert volumes with 0 values.
// This way, if the account has a 0 balance at this point, it will be locked as any other accounts.
// It the complete sql transaction fail, the account volumes will not be inserted.
s.db.NewInsert().
Model(&accountsVolumes).
ModelTableExpr(s.GetPrefixedRelationName("accounts_volumes")).
On("conflict do nothing"),
).
Model(&accountsVolumes).
// Try to insert volumes using last move (to keep compat with previous version) or 0 values.
// This way, if the account has a 0 balance at this point, it will be locked as any other accounts.
// If the complete sql transaction fails, the account volumes will not be inserted.
selectMoves := s.db.NewSelect().
ModelTableExpr(s.GetPrefixedRelationName("moves")).
DistinctOn("accounts_address, asset").
Column("accounts_address", "asset").
ColumnExpr("first_value(post_commit_volumes) over (partition by accounts_address, asset order by seq desc) as post_commit_volumes").
ColumnExpr("first_value(ledger) over (partition by accounts_address, asset order by seq desc) as ledger").
Where("("+strings.Join(conditions, ") OR (")+")", args...)

zeroValuesAndMoves := s.db.NewSelect().
TableExpr("(?) data", selectMoves).
Column("ledger", "accounts_address", "asset").
ColumnExpr("(post_commit_volumes).inputs as input").
ColumnExpr("(post_commit_volumes).outputs as output").
UnionAll(
s.db.NewSelect().
TableExpr(
"(?) data",
s.db.NewSelect().NewValues(&accountsVolumes),
).
Column("*"),
)

zeroValueOrMoves := s.db.NewSelect().
TableExpr("(?) data", zeroValuesAndMoves).
Column("ledger", "accounts_address", "asset", "input", "output").
DistinctOn("ledger, accounts_address, asset")

insertDefaultValue := s.db.NewInsert().
TableExpr(s.GetPrefixedRelationName("accounts_volumes")).
TableExpr("(" + zeroValueOrMoves.String() + ") data").
On("conflict (ledger, accounts_address, asset) do nothing").
Returning("ledger, accounts_address, asset, input, output")

selectExistingValues := s.db.NewSelect().
ModelTableExpr(s.GetPrefixedRelationName("accounts_volumes")).
Column("accounts_address", "asset", "input", "output").
Column("ledger", "accounts_address", "asset", "input", "output").
Where("("+strings.Join(conditions, ") OR (")+")", args...).
For("update").
// notes(gfyrag): Keep order, it ensures consistent locking order and limit deadlocks
Order("accounts_address", "asset").
Scan(ctx)
Order("accounts_address", "asset")

finalQuery := s.db.NewSelect().
With("inserted", insertDefaultValue).
With("existing", selectExistingValues).
ModelTableExpr(
"(?) accounts_volumes",
s.db.NewSelect().
ModelTableExpr("inserted").
UnionAll(s.db.NewSelect().ModelTableExpr("existing")),
).
Model(&accountsVolumes)

err := finalQuery.Scan(ctx)
if err != nil {
return nil, postgres.ResolveError(err)
}
Expand Down
51 changes: 51 additions & 0 deletions internal/storage/ledger/balances_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package ledger_test

import (
"database/sql"
"github.com/formancehq/go-libs/v2/bun/bunpaginate"
"math/big"
"testing"

Expand Down Expand Up @@ -127,6 +128,56 @@ func TestBalancesGet(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 2, count)
})

t.Run("with balance from move", func(t *testing.T) {
t.Parallel()

tx := ledger.NewTransaction().WithPostings(
ledger.NewPosting("world", "bank", "USD", big.NewInt(100)),
)
err := store.InsertTransaction(ctx, &tx)
require.NoError(t, err)

bankAccount := ledger.Account{
Address: "bank",
FirstUsage: tx.InsertedAt,
InsertionDate: tx.InsertedAt,
UpdatedAt: tx.InsertedAt,
}
_, err = store.UpsertAccount(ctx, &bankAccount)
require.NoError(t, err)

err = store.InsertMoves(ctx, &ledger.Move{
TransactionID: tx.ID,
IsSource: false,
Account: "bank",
Amount: (*bunpaginate.BigInt)(big.NewInt(100)),
Asset: "USD",
InsertionDate: tx.InsertedAt,
EffectiveDate: tx.InsertedAt,
PostCommitVolumes: pointer.For(ledger.NewVolumesInt64(100, 0)),
})
require.NoError(t, err)

balances, err := store.GetBalances(ctx, ledgercontroller.BalanceQuery{
"bank": {"USD"},
})
require.NoError(t, err)

require.NotNil(t, balances["bank"])
RequireEqual(t, big.NewInt(100), balances["bank"]["USD"])

// Check a new line has been inserted into accounts_volumes table
volumes := &ledger.AccountsVolumes{}
err = store.GetDB().NewSelect().
ModelTableExpr(store.GetPrefixedRelationName("accounts_volumes")).
Where("accounts_address = ?", "bank").
Scan(ctx, volumes)
require.NoError(t, err)

RequireEqual(t, big.NewInt(100), volumes.Input)
RequireEqual(t, big.NewInt(0), volumes.Output)
})
}

func TestBalancesAggregates(t *testing.T) {
Expand Down

0 comments on commit f4228fd

Please sign in to comment.