diff --git a/cmd/internal/analytics.go b/cmd/internal/analytics.go index 61f07b3c6..153cfcf8e 100644 --- a/cmd/internal/analytics.go +++ b/cmd/internal/analytics.go @@ -1,7 +1,6 @@ package internal import ( - "context" "time" "github.com/Masterminds/semver/v3" @@ -28,17 +27,7 @@ func InitAnalyticsFlags(cmd *cobra.Command, defaultWriteKey string) { func NewAnalyticsModule(v *viper.Viper, version string) fx.Option { if v.GetBool(telemetryEnabledFlag) { - applicationId := viper.GetString(telemetryApplicationIdFlag) - var appIdProviderModule fx.Option - if applicationId == "" { - appIdProviderModule = fx.Provide(analytics.FromStorageAppIdProvider) - } else { - appIdProviderModule = fx.Provide(func() analytics.AppIdProvider { - return analytics.AppIdProviderFn(func(ctx context.Context) (string, error) { - return applicationId, nil - }) - }) - } + applicationID := viper.GetString(telemetryApplicationIdFlag) writeKey := viper.GetString(telemetryWriteKeyFlag) interval := viper.GetDuration(telemetryHeartbeatIntervalFlag) if writeKey == "" { @@ -56,10 +45,7 @@ func NewAnalyticsModule(v *viper.Viper, version string) fx.Option { l.Infof("telemetry enabled but version '%s' is not semver, skip", version) }) } else { - return fx.Options( - appIdProviderModule, - analytics.NewHeartbeatModule(version, writeKey, interval), - ) + return analytics.NewHeartbeatModule(version, writeKey, applicationID, interval) } } } diff --git a/cmd/internal/analytics_test.go b/cmd/internal/analytics_test.go index 6a88932ef..5aa65e391 100644 --- a/cmd/internal/analytics_test.go +++ b/cmd/internal/analytics_test.go @@ -2,18 +2,14 @@ package internal import ( "context" - "net/http" "reflect" "testing" "time" - "github.com/formancehq/ledger/pkg/ledgertesting" - "github.com/formancehq/ledger/pkg/storage" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/stretchr/testify/require" "go.uber.org/fx" - "gopkg.in/segmentio/analytics-go.v3" ) func TestAnalyticsFlags(t *testing.T) { @@ -80,56 +76,6 @@ func TestAnalyticsFlags(t *testing.T) { } } -func TestAnalyticsModule(t *testing.T) { - v := viper.GetViper() - v.Set(telemetryEnabledFlag, true) - v.Set(telemetryWriteKeyFlag, "XXX") - v.Set(telemetryApplicationIdFlag, "appId") - v.Set(telemetryHeartbeatIntervalFlag, 10*time.Second) - - handled := make(chan struct{}) - - module := NewAnalyticsModule(v, "1.0.0") - app := fx.New( - module, - fx.NopLogger, - fx.Provide(func(lc fx.Lifecycle) (storage.Driver, error) { - driver := ledgertesting.StorageDriver(t) - lc.Append(fx.Hook{ - OnStart: driver.Initialize, - OnStop: func(ctx context.Context) error { - return driver.Close(ctx) - }, - }) - return driver, nil - }), - fx.Replace(analytics.Config{ - BatchSize: 1, - Transport: roundTripperFn(func(req *http.Request) (*http.Response, error) { - select { - case <-handled: - // Nothing to do, the chan has already been closed - default: - close(handled) - } - return &http.Response{ - StatusCode: http.StatusOK, - }, nil - }), - })) - require.NoError(t, app.Start(context.Background())) - defer func() { - require.NoError(t, app.Stop(context.Background())) - }() - - select { - case <-time.After(time.Second): - require.Fail(t, "Timeout waiting first stats from analytics module") - case <-handled: - } - -} - func TestAnalyticsModuleDisabled(t *testing.T) { v := viper.GetViper() v.Set(telemetryEnabledFlag, false) diff --git a/cmd/internal/utils.go b/cmd/internal/utils.go index b5308412a..c886ed8b4 100644 --- a/cmd/internal/utils.go +++ b/cmd/internal/utils.go @@ -1,7 +1,6 @@ package internal import ( - "net/http" "os" "strings" ) @@ -14,9 +13,3 @@ func setEnvVar(key, value string) func() { os.Setenv(flag, oldEnv) } } - -type roundTripperFn func(req *http.Request) (*http.Response, error) - -func (fn roundTripperFn) RoundTrip(req *http.Request) (*http.Response, error) { - return fn(req) -} diff --git a/go.mod b/go.mod index 10ee23034..be2fd0295 100755 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/formancehq/stack/libs/go-libs v0.0.0-20230222164357-55840b21a337 github.com/go-chi/chi/v5 v5.0.8 github.com/go-chi/cors v1.2.1 + github.com/golang/mock v1.4.4 github.com/google/uuid v1.3.0 github.com/jackc/pgx/v5 v5.3.0 github.com/lib/pq v1.10.7 diff --git a/go.sum b/go.sum index eabb26f27..4cc0dcfbd 100644 --- a/go.sum +++ b/go.sum @@ -180,6 +180,7 @@ github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFU github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/pkg/analytics/analytics.go b/pkg/analytics/analytics.go new file mode 100644 index 000000000..73de89da0 --- /dev/null +++ b/pkg/analytics/analytics.go @@ -0,0 +1,148 @@ +package analytics + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "runtime" + "time" + + "github.com/formancehq/ledger/pkg/core" + "github.com/formancehq/stack/libs/go-libs/logging" + "github.com/pbnjay/memory" + "gopkg.in/segmentio/analytics-go.v3" +) + +const ( + ApplicationStats = "Application stats" + + VersionProperty = "version" + AccountsProperty = "accounts" + TransactionsProperty = "transactions" + LedgersProperty = "ledgers" + OSProperty = "os" + ArchProperty = "arch" + TimeZoneProperty = "tz" + CPUCountProperty = "cpuCount" + TotalMemoryProperty = "totalMemory" +) + +type heartbeat struct { + version string + interval time.Duration + client analytics.Client + stopChan chan chan struct{} + backend Backend +} + +func (m *heartbeat) Run(ctx context.Context) error { + + enqueue := func() { + err := m.enqueue(ctx) + if err != nil { + logging.FromContext(ctx).WithFields(map[string]interface{}{ + "error": err, + }).Error("enqueuing analytics") + } + } + + enqueue() + for { + select { + case ch := <-m.stopChan: + ch <- struct{}{} + return nil + case <-ctx.Done(): + return ctx.Err() + case <-time.After(m.interval): + enqueue() + } + } +} + +func (m *heartbeat) Stop(ctx context.Context) error { + ch := make(chan struct{}) + m.stopChan <- ch + select { + case <-ctx.Done(): + return ctx.Err() + case <-ch: + return nil + } +} + +func (m *heartbeat) enqueue(ctx context.Context) error { + + appID, err := m.backend.AppID(ctx) + if err != nil { + return err + } + + tz, _ := core.Now().Local().Zone() + + properties := analytics.NewProperties(). + Set(VersionProperty, m.version). + Set(OSProperty, runtime.GOOS). + Set(ArchProperty, runtime.GOARCH). + Set(TimeZoneProperty, tz). + Set(CPUCountProperty, runtime.NumCPU()). + Set(TotalMemoryProperty, memory.TotalMemory()/1024/1024) + + ledgers, err := m.backend.ListLedgers(ctx) + if err != nil { + return err + } + + ledgersProperty := map[string]any{} + + for _, l := range ledgers { + stats := map[string]any{} + if err := func() error { + store, _, err := m.backend.GetLedgerStore(ctx, l, false) + if err != nil { + return err + } + + transactions, err := store.CountTransactions(ctx) + if err != nil { + return err + } + + accounts, err := store.CountAccounts(ctx) + if err != nil { + return err + } + stats[TransactionsProperty] = transactions + stats[AccountsProperty] = accounts + + return nil + }(); err != nil { + return err + } + + digest := sha256.New() + digest.Write([]byte(l)) + ledgerHash := base64.RawURLEncoding.EncodeToString(digest.Sum(nil)) + + ledgersProperty[ledgerHash] = stats + } + if len(ledgersProperty) > 0 { + properties.Set(LedgersProperty, ledgersProperty) + } + + return m.client.Enqueue(&analytics.Track{ + AnonymousId: appID, + Event: ApplicationStats, + Properties: properties, + }) +} + +func newHeartbeat(backend Backend, client analytics.Client, version string, interval time.Duration) *heartbeat { + return &heartbeat{ + version: version, + interval: interval, + client: client, + backend: backend, + stopChan: make(chan chan struct{}, 1), + } +} diff --git a/pkg/analytics/segment_test.go b/pkg/analytics/analytics_test.go similarity index 52% rename from pkg/analytics/segment_test.go rename to pkg/analytics/analytics_test.go index 8ddb03a27..1552c0672 100644 --- a/pkg/analytics/segment_test.go +++ b/pkg/analytics/analytics_test.go @@ -11,11 +11,8 @@ import ( "testing" "time" - "github.com/formancehq/ledger/pkg/ledgertesting" - "github.com/formancehq/ledger/pkg/storage" - "github.com/formancehq/stack/libs/go-libs/pgtesting" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" - "go.uber.org/fx" "gopkg.in/segmentio/analytics-go.v3" ) @@ -76,28 +73,6 @@ const ( writeKey = "key" ) -func module(t pgtesting.TestingT) fx.Option { - return fx.Options( - NewHeartbeatModule(version, writeKey, interval), - fx.NopLogger, - fx.Provide(func() AppIdProvider { - return AppIdProviderFn(func(ctx context.Context) (string, error) { - return "foo", nil - }) - }), - fx.Provide(func(lc fx.Lifecycle) (storage.Driver, error) { - driver := ledgertesting.StorageDriver(t) - lc.Append(fx.Hook{ - OnStart: driver.Initialize, - OnStop: func(ctx context.Context) error { - return driver.Close(ctx) - }, - }) - return driver, nil - }), - ) -} - func EventuallyQueueNotEmpty[ITEM any](t *testing.T, queue *Queue[ITEM]) { require.Eventually(t, func() bool { return !queue.Empty() @@ -109,30 +84,87 @@ var emptyHttpResponse = &http.Response{ StatusCode: http.StatusOK, } -func newApp(module fx.Option, t transport) *fx.App { - return fx.New(module, fx.Replace(analytics.Config{ - BatchSize: 1, - Transport: t, - })) -} +func TestAnalytics(t *testing.T) { + t.Parallel() -func withApp(t *testing.T, app *fx.App, fn func(t *testing.T)) { - require.NoError(t, app.Start(context.Background())) - defer func() { - require.NoError(t, app.Stop(context.Background())) - }() - fn(t) -} + type testCase struct { + name string + transport http.RoundTripper + } + queue := NewQueue[*http.Request]() + firstCallChan := make(chan struct{}) + testCases := []testCase{ + { + name: "nominal", + transport: transport(func(request *http.Request) (*http.Response, error) { + queue.Put(request) + return emptyHttpResponse, nil + }), + }, + { + name: "with error on backend", + transport: transport(func(request *http.Request) (*http.Response, error) { + select { + case <-firstCallChan: // Enter this case only if the chan is closed + queue.Put(request) + return emptyHttpResponse, nil + default: + close(firstCallChan) + return nil, errors.New("general error") + } + }), + }, + } -func TestSegment(t *testing.T) { + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + mockLedger := NewMockLedger(ctrl) + backend := NewMockBackend(ctrl) + backend. + EXPECT(). + ListLedgers(gomock.Any()). + AnyTimes(). + Return([]string{"default"}, nil) + backend. + EXPECT(). + AppID(gomock.Any()). + AnyTimes(). + Return(applicationId, nil) + backend. + EXPECT(). + GetLedgerStore(gomock.Any(), "default", false). + AnyTimes(). + Return(mockLedger, false, nil) + t.Cleanup(func() { + ctrl.Finish() + }) + analyticsClient, err := analytics.NewWithConfig(writeKey, analytics.Config{ + BatchSize: 1, + Transport: testCase.transport, + }) + require.NoError(t, err) + + mockLedger. + EXPECT(). + CountTransactions(gomock.Any()). + AnyTimes(). + Return(uint64(10), nil) + mockLedger. + EXPECT(). + CountAccounts(gomock.Any()). + AnyTimes(). + Return(uint64(20), nil) + + h := newHeartbeat(backend, analyticsClient, version, interval) + go func() { + require.NoError(t, h.Run(context.Background())) + }() + t.Cleanup(func() { + require.NoError(t, h.Stop(context.Background())) + }) - t.Run("Nominal case", func(t *testing.T) { - queue := NewQueue[*http.Request]() - app := newApp(module(t), func(request *http.Request) (*http.Response, error) { - queue.Put(request) - return emptyHttpResponse, nil - }) - withApp(t, app, func(t *testing.T) { for i := 0; i < 10; i++ { EventuallyQueueNotEmpty(t, queue) request, ok := queue.Get() @@ -153,26 +185,5 @@ func TestSegment(t *testing.T) { require.Equal(t, applicationId, track.AnonymousId) } }) - }) - t.Run("With error on the backend", func(t *testing.T) { - firstCallChan := make(chan struct{}) - - queue := NewQueue[*http.Request]() - app := newApp(module(t), func(request *http.Request) (*http.Response, error) { - select { - case <-firstCallChan: // Enter this case only if the chan is closed - queue.Put(request) - return emptyHttpResponse, nil - default: - close(firstCallChan) - return nil, errors.New("general error") - } - }) - withApp(t, app, func(t *testing.T) { - EventuallyQueueNotEmpty(t, queue) - - _, ok := queue.Get() - require.True(t, ok) - }) - }) + } } diff --git a/pkg/analytics/backend.go b/pkg/analytics/backend.go new file mode 100644 index 000000000..87fa06804 --- /dev/null +++ b/pkg/analytics/backend.go @@ -0,0 +1,81 @@ +package analytics + +import ( + "context" + + "github.com/formancehq/ledger/pkg/storage" + "github.com/google/uuid" + "github.com/pkg/errors" +) + +//go:generate mockgen -source backend.go -destination backend_test.go -package analytics . Ledger + +type Ledger interface { + CountTransactions(ctx context.Context) (uint64, error) + CountAccounts(ctx context.Context) (uint64, error) +} + +type defaultLedger struct { + store storage.LedgerStore +} + +func (d defaultLedger) CountTransactions(ctx context.Context) (uint64, error) { + return d.store.CountTransactions(ctx, *storage.NewTransactionsQuery()) +} + +func (d defaultLedger) CountAccounts(ctx context.Context) (uint64, error) { + return d.store.CountAccounts(ctx, *storage.NewAccountsQuery()) +} + +var _ Ledger = (*defaultLedger)(nil) + +type Backend interface { + AppID(ctx context.Context) (string, error) + ListLedgers(ctx context.Context) ([]string, error) + GetLedgerStore(ctx context.Context, l string, b bool) (Ledger, bool, error) +} + +type defaultBackend struct { + driver storage.Driver + appID string +} + +func (d defaultBackend) AppID(ctx context.Context) (string, error) { + var err error + if d.appID == "" { + d.appID, err = d.driver.GetSystemStore().GetConfiguration(ctx, "appId") + if err != nil && !errors.Is(err, storage.ErrNotFound) { + return "", err + } + if errors.Is(err, storage.ErrNotFound) { + d.appID = uuid.NewString() + if err := d.driver.GetSystemStore().InsertConfiguration(ctx, "appId", d.appID); err != nil { + return "", err + } + } + } + return d.appID, nil +} + +func (d defaultBackend) ListLedgers(ctx context.Context) ([]string, error) { + return d.driver.GetSystemStore().ListLedgers(ctx) +} + +func (d defaultBackend) GetLedgerStore(ctx context.Context, name string, create bool) (Ledger, bool, error) { + ledgerStore, created, err := d.driver.GetLedgerStore(ctx, name, create) + if err != nil { + return nil, false, err + } + return &defaultLedger{ + store: ledgerStore, + }, created, nil +} + +var _ Backend = (*defaultBackend)(nil) + +func newDefaultBackend(driver storage.Driver, appID string) *defaultBackend { + return &defaultBackend{ + driver: driver, + appID: appID, + } +} diff --git a/pkg/analytics/backend_test.go b/pkg/analytics/backend_test.go new file mode 100644 index 000000000..384a475cc --- /dev/null +++ b/pkg/analytics/backend_test.go @@ -0,0 +1,134 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: backend.go + +// Package analytics is a generated GoMock package. +package analytics + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockLedger is a mock of Ledger interface. +type MockLedger struct { + ctrl *gomock.Controller + recorder *MockLedgerMockRecorder +} + +// MockLedgerMockRecorder is the mock recorder for MockLedger. +type MockLedgerMockRecorder struct { + mock *MockLedger +} + +// NewMockLedger creates a new mock instance. +func NewMockLedger(ctrl *gomock.Controller) *MockLedger { + mock := &MockLedger{ctrl: ctrl} + mock.recorder = &MockLedgerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLedger) EXPECT() *MockLedgerMockRecorder { + return m.recorder +} + +// CountAccounts mocks base method. +func (m *MockLedger) CountAccounts(ctx context.Context) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAccounts", ctx) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAccounts indicates an expected call of CountAccounts. +func (mr *MockLedgerMockRecorder) CountAccounts(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccounts", reflect.TypeOf((*MockLedger)(nil).CountAccounts), ctx) +} + +// CountTransactions mocks base method. +func (m *MockLedger) CountTransactions(ctx context.Context) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountTransactions", ctx) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountTransactions indicates an expected call of CountTransactions. +func (mr *MockLedgerMockRecorder) CountTransactions(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountTransactions", reflect.TypeOf((*MockLedger)(nil).CountTransactions), ctx) +} + +// MockBackend is a mock of Backend interface. +type MockBackend struct { + ctrl *gomock.Controller + recorder *MockBackendMockRecorder +} + +// MockBackendMockRecorder is the mock recorder for MockBackend. +type MockBackendMockRecorder struct { + mock *MockBackend +} + +// NewMockBackend creates a new mock instance. +func NewMockBackend(ctrl *gomock.Controller) *MockBackend { + mock := &MockBackend{ctrl: ctrl} + mock.recorder = &MockBackendMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBackend) EXPECT() *MockBackendMockRecorder { + return m.recorder +} + +// AppID mocks base method. +func (m *MockBackend) AppID(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AppID", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AppID indicates an expected call of AppID. +func (mr *MockBackendMockRecorder) AppID(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppID", reflect.TypeOf((*MockBackend)(nil).AppID), ctx) +} + +// GetLedgerStore mocks base method. +func (m *MockBackend) GetLedgerStore(ctx context.Context, l string, b bool) (Ledger, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLedgerStore", ctx, l, b) + ret0, _ := ret[0].(Ledger) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetLedgerStore indicates an expected call of GetLedgerStore. +func (mr *MockBackendMockRecorder) GetLedgerStore(ctx, l, b interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLedgerStore", reflect.TypeOf((*MockBackend)(nil).GetLedgerStore), ctx, l, b) +} + +// ListLedgers mocks base method. +func (m *MockBackend) ListLedgers(ctx context.Context) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListLedgers", ctx) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListLedgers indicates an expected call of ListLedgers. +func (mr *MockBackendMockRecorder) ListLedgers(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListLedgers", reflect.TypeOf((*MockBackend)(nil).ListLedgers), ctx) +} diff --git a/pkg/analytics/main_test.go b/pkg/analytics/main_test.go deleted file mode 100644 index 59a0ab5a2..000000000 --- a/pkg/analytics/main_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package analytics - -import ( - "os" - "testing" - - "github.com/formancehq/stack/libs/go-libs/logging" - "github.com/formancehq/stack/libs/go-libs/pgtesting" -) - -func TestMain(t *testing.M) { - if err := pgtesting.CreatePostgresServer(); err != nil { - logging.Error(err) - os.Exit(1) - } - code := t.Run() - if err := pgtesting.DestroyPostgresServer(); err != nil { - logging.Error(err) - } - os.Exit(code) -} diff --git a/pkg/analytics/module.go b/pkg/analytics/module.go new file mode 100644 index 000000000..c91507ee8 --- /dev/null +++ b/pkg/analytics/module.go @@ -0,0 +1,48 @@ +package analytics + +import ( + "context" + "time" + + "github.com/formancehq/ledger/pkg/storage" + "go.uber.org/fx" + "gopkg.in/segmentio/analytics-go.v3" +) + +func NewHeartbeatModule(version, writeKey, appID string, interval time.Duration) fx.Option { + return fx.Options( + fx.Supply(analytics.Config{}), // Provide empty config to be able to replace (use fx.Replace) if necessary + fx.Provide(func(cfg analytics.Config) (analytics.Client, error) { + return analytics.NewWithConfig(writeKey, cfg) + }), + fx.Provide(func(client analytics.Client, backend Backend) *heartbeat { + return newHeartbeat(backend, client, version, interval) + }), + fx.Provide(func(driver storage.Driver) Backend { + return newDefaultBackend(driver, appID) + }), + fx.Invoke(func(m *heartbeat, lc fx.Lifecycle) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + go func() { + err := m.Run(context.Background()) + if err != nil { + panic(err) + } + }() + return nil + }, + OnStop: func(ctx context.Context) error { + return m.Stop(ctx) + }, + }) + }), + fx.Invoke(func(lc fx.Lifecycle, client analytics.Client) { + lc.Append(fx.Hook{ + OnStop: func(ctx context.Context) error { + return client.Close() + }, + }) + }), + ) +} diff --git a/pkg/analytics/segment.go b/pkg/analytics/segment.go deleted file mode 100644 index b10e03237..000000000 --- a/pkg/analytics/segment.go +++ /dev/null @@ -1,215 +0,0 @@ -package analytics - -import ( - "context" - "crypto/sha256" - "encoding/base64" - "runtime" - "time" - - "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/storage" - "github.com/formancehq/stack/libs/go-libs/logging" - "github.com/pbnjay/memory" - "github.com/pborman/uuid" - "go.uber.org/fx" - "gopkg.in/segmentio/analytics-go.v3" -) - -const ( - ApplicationStats = "Application stats" - - VersionProperty = "version" - AccountsProperty = "accounts" - TransactionsProperty = "transactions" - LedgersProperty = "ledgers" - OSProperty = "os" - ArchProperty = "arch" - TimeZoneProperty = "tz" - CPUCountProperty = "cpuCount" - TotalMemoryProperty = "totalMemory" -) - -type AppIdProvider interface { - AppID(ctx context.Context) (string, error) -} -type AppIdProviderFn func(ctx context.Context) (string, error) - -func (fn AppIdProviderFn) AppID(ctx context.Context) (string, error) { - return fn(ctx) -} - -func FromStorageAppIdProvider(driver storage.Driver) AppIdProvider { - var appId string - return AppIdProviderFn(func(ctx context.Context) (string, error) { - var err error - if appId == "" { - appId, err = driver.GetSystemStore().GetConfiguration(ctx, "appId") - if err != nil && !storage.IsNotFoundError(err) { - return "", err - } - if storage.IsNotFoundError(err) { - appId = uuid.New() - if err := driver.GetSystemStore().InsertConfiguration(ctx, "appId", appId); err != nil { - return "", err - } - } - } - return appId, nil - }) -} - -type heartbeat struct { - version string - interval time.Duration - client analytics.Client - stopChan chan chan struct{} - appIdProvider AppIdProvider - driver storage.Driver -} - -func (m *heartbeat) Run(ctx context.Context) error { - - enqueue := func() { - err := m.enqueue(ctx) - if err != nil { - logging.FromContext(ctx).WithFields(map[string]interface{}{ - "error": err, - }).Error("enqueuing analytics") - } - } - - enqueue() - for { - select { - case ch := <-m.stopChan: - ch <- struct{}{} - return nil - case <-ctx.Done(): - return ctx.Err() - case <-time.After(m.interval): - enqueue() - } - } -} - -func (m *heartbeat) Stop(ctx context.Context) error { - ch := make(chan struct{}) - m.stopChan <- ch - select { - case <-ctx.Done(): - return ctx.Err() - case <-ch: - return nil - } -} - -func (m *heartbeat) enqueue(ctx context.Context) error { - - appId, err := m.appIdProvider.AppID(ctx) - if err != nil { - return err - } - - tz, _ := core.Now().Local().Zone() - - properties := analytics.NewProperties(). - Set(VersionProperty, m.version). - Set(OSProperty, runtime.GOOS). - Set(ArchProperty, runtime.GOARCH). - Set(TimeZoneProperty, tz). - Set(CPUCountProperty, runtime.NumCPU()). - Set(TotalMemoryProperty, memory.TotalMemory()/1024/1024) - - ledgers, err := m.driver.GetSystemStore().ListLedgers(ctx) - if err != nil { - return err - } - - ledgersProperty := map[string]any{} - - for _, l := range ledgers { - stats := map[string]any{} - if err := func() error { - store, _, err := m.driver.GetLedgerStore(ctx, l, false) - if err != nil { - return err - } - transactions, err := store.CountTransactions(ctx, storage.TransactionsQuery{}) - if err != nil { - return err - } - accounts, err := store.CountAccounts(ctx, storage.AccountsQuery{}) - if err != nil { - return err - } - stats[TransactionsProperty] = transactions - stats[AccountsProperty] = accounts - - return nil - }(); err != nil { - return err - } - - digest := sha256.New() - digest.Write([]byte(l)) - ledgerHash := base64.RawURLEncoding.EncodeToString(digest.Sum(nil)) - - ledgersProperty[ledgerHash] = stats - } - if len(ledgersProperty) > 0 { - properties.Set(LedgersProperty, ledgersProperty) - } - - return m.client.Enqueue(&analytics.Track{ - AnonymousId: appId, - Event: ApplicationStats, - Properties: properties, - }) -} - -func newHeartbeat(appIdProvider AppIdProvider, driver storage.Driver, client analytics.Client, version string, interval time.Duration) *heartbeat { - return &heartbeat{ - version: version, - interval: interval, - client: client, - driver: driver, - appIdProvider: appIdProvider, - stopChan: make(chan chan struct{}, 1), - } -} - -func NewHeartbeatModule(version, writeKey string, interval time.Duration) fx.Option { - return fx.Options( - fx.Supply(analytics.Config{}), // Provide empty config to be able to replace (use fx.Replace) if necessary - fx.Provide(func(cfg analytics.Config) (analytics.Client, error) { - return analytics.NewWithConfig(writeKey, cfg) - }), - fx.Provide(func(client analytics.Client, provider AppIdProvider, driver storage.Driver) *heartbeat { - return newHeartbeat(provider, driver, client, version, interval) - }), - fx.Invoke(func(m *heartbeat, lc fx.Lifecycle) { - lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { - go func() { - err := m.Run(context.Background()) - if err != nil { - panic(err) - } - }() - return nil - }, - OnStop: func(ctx context.Context) error { - return m.Stop(ctx) - }, - }) - }), - fx.Invoke(func(lc fx.Lifecycle, client analytics.Client) { - lc.Append(fx.Hook{ - OnStop: func(ctx context.Context) error { - return client.Close() - }, - }) - }), - ) -} diff --git a/pkg/api/api.go b/pkg/api/api.go index c9a85aeaa..91b7d8e92 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -3,12 +3,11 @@ package api import ( _ "embed" + "github.com/formancehq/ledger/pkg/api/controllers" "github.com/formancehq/ledger/pkg/api/routes" "github.com/formancehq/ledger/pkg/ledger" "github.com/formancehq/ledger/pkg/storage" "github.com/formancehq/stack/libs/go-libs/health" - "github.com/formancehq/stack/libs/go-libs/logging" - "github.com/go-chi/chi/v5" "go.uber.org/fx" ) @@ -18,9 +17,9 @@ type Config struct { func Module(cfg Config) fx.Option { return fx.Options( - fx.Provide(func(storageDriver storage.Driver, resolver *ledger.Resolver, logger logging.Logger, - healthController *health.HealthController) chi.Router { - return routes.NewRouter(storageDriver, cfg.Version, resolver, logger, healthController) + fx.Provide(routes.NewRouter), + fx.Provide(func(storageDriver storage.Driver, resolver *ledger.Resolver) controllers.Backend { + return controllers.NewDefaultBackend(storageDriver, cfg.Version, resolver) }), health.Module(), ) diff --git a/pkg/api/controllers/account_controller.go b/pkg/api/controllers/account_controller.go index afa273e25..ae4704443 100644 --- a/pkg/api/controllers/account_controller.go +++ b/pkg/api/controllers/account_controller.go @@ -117,15 +117,7 @@ func GetAccounts(w http.ResponseWriter, r *http.Request) { func GetAccount(w http.ResponseWriter, r *http.Request) { l := LedgerFromContext(r.Context()) - if !core.ValidateAddress(chi.URLParam(r, "address")) { - apierrors.ResponseError(w, r, errorsutil.NewError(ledger.ErrValidation, - errors.New("invalid account address format"))) - return - } - - acc, err := l.GetAccount( - r.Context(), - chi.URLParam(r, "address")) + acc, err := l.GetAccount(r.Context(), chi.URLParam(r, "address")) if err != nil { apierrors.ResponseError(w, r, err) return diff --git a/pkg/api/controllers/account_controller_test.go b/pkg/api/controllers/account_controller_test.go index 6dcb2a395..949accee6 100644 --- a/pkg/api/controllers/account_controller_test.go +++ b/pkg/api/controllers/account_controller_test.go @@ -1,527 +1,299 @@ package controllers_test import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "math/big" "net/http" + "net/http/httptest" "net/url" "testing" "github.com/formancehq/ledger/pkg/api/apierrors" "github.com/formancehq/ledger/pkg/api/controllers" - "github.com/formancehq/ledger/pkg/api/internal" + "github.com/formancehq/ledger/pkg/api/routes" "github.com/formancehq/ledger/pkg/core" "github.com/formancehq/ledger/pkg/storage" ledgerstore "github.com/formancehq/ledger/pkg/storage/sqlstorage/ledger" sharedapi "github.com/formancehq/stack/libs/go-libs/api" - "github.com/go-chi/chi/v5" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" ) func TestGetAccounts(t *testing.T) { - internal.RunTest(t, func(api chi.Router, storageDriver storage.Driver) { - store, _, err := storageDriver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) - _, err = store.Initialize(context.Background()) - require.NoError(t, err) - require.NoError(t, store.EnsureAccountExists(context.Background(), "world")) - require.NoError(t, store.EnsureAccountExists(context.Background(), "alice")) - require.NoError(t, store.EnsureAccountExists(context.Background(), "bob")) - meta := core.Metadata{ - "roles": "admin", - "accountId": float64(3), - "enabled": "true", - "a": map[string]any{ - "nested": map[string]any{ - "key": "hello", - }, - }, - } - require.NoError(t, store.UpdateAccountMetadata(context.Background(), "bob", meta)) - require.NoError(t, store.UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ - "world": { - "USD": core.NewEmptyVolumes().WithOutput(big.NewInt(250)), + t.Parallel() + + type testCase struct { + name string + queryParams url.Values + expectQuery storage.AccountsQuery + expectStatusCode int + expectedErrorCode string + } + + testCases := []testCase{ + { + name: "nominal", + expectQuery: *storage.NewAccountsQuery(). + WithBalanceOperatorFilter("gte"), + }, + { + name: "using metadata", + queryParams: url.Values{ + "metadata[roles]": []string{"admin"}, }, - "alice": { - "USD": core.NewEmptyVolumes().WithInput(big.NewInt(150)), + expectQuery: *storage.NewAccountsQuery(). + WithBalanceOperatorFilter("gte"). + WithMetadataFilter(map[string]string{ + "roles": "admin", + }), + }, + { + name: "using nested metadata", + queryParams: url.Values{ + "metadata[a.nested.key]": []string{"hello"}, }, - "bob": { - "USD": core.NewEmptyVolumes().WithInput(big.NewInt(100)), + expectQuery: *storage.NewAccountsQuery(). + WithBalanceOperatorFilter("gte"). + WithMetadataFilter(map[string]string{ + "a.nested.key": "hello", + }), + }, + { + name: "using after", + queryParams: url.Values{ + "after": []string{"foo"}, }, - })) - - rsp := internal.CountAccounts(api, url.Values{}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - require.Equal(t, "3", rsp.Header().Get("Count")) - - t.Run("all", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - // 3 accounts: world, bob, alice - require.Len(t, cursor.Data, 3) - require.Equal(t, []core.Account{ - {Address: "world", Metadata: core.Metadata{}}, - {Address: "bob", Metadata: meta}, - {Address: "alice", Metadata: core.Metadata{}}, - }, cursor.Data) - }) - - t.Run("meta roles", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "metadata[roles]": []string{"admin"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - // 1 accounts: bob - require.Len(t, cursor.Data, 1) - require.Equal(t, "bob", string(cursor.Data[0].Address)) - }) - - t.Run("meta accountId", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "metadata[accountId]": []string{"3"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - // 1 accounts: bob - require.Len(t, cursor.Data, 1) - require.Equal(t, "bob", string(cursor.Data[0].Address)) - }) - - t.Run("meta enabled", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "metadata[enabled]": []string{"true"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - // 1 accounts: bob - require.Len(t, cursor.Data, 1) - require.Equal(t, "bob", string(cursor.Data[0].Address)) - }) - - t.Run("meta nested", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "metadata[a.nested.key]": []string{"hello"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - // 1 accounts: bob - require.Len(t, cursor.Data, 1) - require.Equal(t, "bob", string(cursor.Data[0].Address)) - }) - - t.Run("meta unknown", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "metadata[unknown]": []string{"key"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, 0) - }) - - t.Run("after", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "after": []string{"bob"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - // 1 accounts: alice - require.Len(t, cursor.Data, 1) - require.Equal(t, "alice", string(cursor.Data[0].Address)) - }) - - t.Run("address", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "address": []string{"b.b"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - // 1 accounts: bob - require.Len(t, cursor.Data, 1) - require.Equal(t, "bob", string(cursor.Data[0].Address)) - }) - - to := ledgerstore.AccountsPaginationToken{} - raw, err := json.Marshal(to) - require.NoError(t, err) - - t.Run(fmt.Sprintf("valid empty %s", controllers.QueryKeyCursor), func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - controllers.QueryKeyCursor: []string{base64.RawURLEncoding.EncodeToString(raw)}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode, rsp.Body.String()) - }) - - t.Run(fmt.Sprintf("valid empty %s with any other param is forbidden", controllers.QueryKeyCursor), func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - controllers.QueryKeyCursor: []string{base64.RawURLEncoding.EncodeToString(raw)}, - "after": []string{"bob"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: fmt.Sprintf("no other query params can be set with '%s'", controllers.QueryKeyCursor), - }, err) - }) - - t.Run(fmt.Sprintf("invalid %s", controllers.QueryKeyCursor), func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - controllers.QueryKeyCursor: []string{"invalid"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: fmt.Sprintf("invalid '%s' query param", controllers.QueryKeyCursor), - }, err) - }) - - t.Run(fmt.Sprintf("invalid %s not base64", controllers.QueryKeyCursor), func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - controllers.QueryKeyCursor: []string{"\n*@"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: fmt.Sprintf("invalid '%s' query param", controllers.QueryKeyCursor), - }, err) - }) - - t.Run("filter by balance >= 50 with default operator", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ + expectQuery: *storage.NewAccountsQuery(). + WithBalanceOperatorFilter("gte"). + WithAfterAddress("foo"). + WithMetadataFilter(map[string]string{}), + }, + { + name: "using balance with default operator", + queryParams: url.Values{ "balance": []string{"50"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, 2) - require.Equal(t, "bob", string(cursor.Data[0].Address)) - require.Equal(t, "alice", string(cursor.Data[1].Address)) - }) - - t.Run("filter by balance >= 120 with default operator", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "balance": []string{"120"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, 1) - require.Equal(t, "alice", string(cursor.Data[0].Address)) - }) - - t.Run("filter by balance >= 50", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "balance": []string{"50"}, - controllers.QueryKeyBalanceOperator: []string{"gte"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, 2) - require.Equal(t, "bob", string(cursor.Data[0].Address)) - require.Equal(t, "alice", string(cursor.Data[1].Address)) - }) - - t.Run("filter by balance >= 120", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "balance": []string{"120"}, - controllers.QueryKeyBalanceOperator: []string{"gte"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, 1) - require.Equal(t, "alice", string(cursor.Data[0].Address)) - }) - - t.Run("filter by balance > 120", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "balance": []string{"120"}, - controllers.QueryKeyBalanceOperator: []string{"gt"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, 1) - require.Equal(t, "alice", string(cursor.Data[0].Address)) - }) - - t.Run("filter by balance < 0", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "balance": []string{"0"}, - controllers.QueryKeyBalanceOperator: []string{"lt"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, 1) - require.Equal(t, "world", string(cursor.Data[0].Address)) - }) - - t.Run("filter by balance < 100", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "balance": []string{"100"}, - controllers.QueryKeyBalanceOperator: []string{"lt"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, 1) - require.Equal(t, "world", string(cursor.Data[0].Address)) - }) - - t.Run("filter by balance <= 100", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "balance": []string{"100"}, - controllers.QueryKeyBalanceOperator: []string{"lte"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, 2) - require.Equal(t, "world", string(cursor.Data[0].Address)) - require.Equal(t, "bob", string(cursor.Data[1].Address)) - }) - - t.Run("filter by balance = 100", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "balance": []string{"100"}, - controllers.QueryKeyBalanceOperator: []string{"e"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, 1) - require.Equal(t, "bob", string(cursor.Data[0].Address)) - }) - - // test filter by balance != 100 - t.Run("filter by balance != 100", func(t *testing.T) { - rsp = internal.GetAccounts(api, url.Values{ - "balance": []string{"100"}, - controllers.QueryKeyBalanceOperator: []string{"ne"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, 2) - require.Equal(t, "world", string(cursor.Data[0].Address)) - require.Equal(t, "alice", string(cursor.Data[1].Address)) - }) - - t.Run("invalid balance", func(t *testing.T) { - rsp := internal.GetAccounts(api, url.Values{ - "balance": []string{"toto"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid parameter 'balance', should be a number", - }, err) - }) - - t.Run("invalid balance operator", func(t *testing.T) { - rsp := internal.GetAccounts(api, url.Values{ - "balance": []string{"100"}, - controllers.QueryKeyBalanceOperator: []string{"toto"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: controllers.ErrInvalidBalanceOperator.Error(), - }, err) - }) - }) -} - -func TestGetAccountsWithPageSize(t *testing.T) { - internal.RunTest(t, func(api chi.Router, driver storage.Driver) { - store := internal.GetLedgerStore(t, driver, context.Background()) - - _, err := store.Initialize(context.Background()) - require.NoError(t, err) - - for i := 0; i < 3*controllers.MaxPageSize; i++ { - require.NoError(t, store.UpdateAccountMetadata(context.Background(), fmt.Sprintf("accounts:%06d", i), core.Metadata{ - "foo": []byte("{}"), - })) - } - - t.Run("invalid page size", func(t *testing.T) { - rsp := internal.GetAccounts(api, url.Values{ - controllers.QueryKeyPageSize: []string{"nan"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: controllers.ErrInvalidPageSize.Error(), - }, err) - }) - t.Run("page size over maximum", func(t *testing.T) { - httpResponse := internal.GetAccounts(api, url.Values{ - controllers.QueryKeyPageSize: []string{fmt.Sprintf("%d", 2*controllers.MaxPageSize)}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) - - cursor := internal.DecodeCursorResponse[core.Account](t, httpResponse.Body) - require.Len(t, cursor.Data, controllers.MaxPageSize) - require.Equal(t, cursor.PageSize, controllers.MaxPageSize) - require.NotEmpty(t, cursor.Next) - require.True(t, cursor.HasMore) - }) - t.Run("with page size greater than max count", func(t *testing.T) { - httpResponse := internal.GetAccounts(api, url.Values{ - controllers.QueryKeyPageSize: []string{fmt.Sprintf("%d", controllers.MaxPageSize)}, - "after": []string{fmt.Sprintf("accounts:%06d", controllers.MaxPageSize-100)}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) - - cursor := internal.DecodeCursorResponse[core.Account](t, httpResponse.Body) - require.Len(t, cursor.Data, controllers.MaxPageSize-100) - require.Equal(t, controllers.MaxPageSize, cursor.PageSize) - require.Empty(t, cursor.Next) - require.False(t, cursor.HasMore) - }) - t.Run("with page size lower than max count", func(t *testing.T) { - httpResponse := internal.GetAccounts(api, url.Values{ - controllers.QueryKeyPageSize: []string{fmt.Sprintf("%d", controllers.MaxPageSize/10)}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) - - cursor := internal.DecodeCursorResponse[core.Account](t, httpResponse.Body) - require.Len(t, cursor.Data, controllers.MaxPageSize/10) - require.Equal(t, cursor.PageSize, controllers.MaxPageSize/10) - require.NotEmpty(t, cursor.Next) - require.True(t, cursor.HasMore) - }) - }) -} - -func TestGetAccount(t *testing.T) { - internal.RunTest(t, func(api chi.Router, storageDriver storage.Driver) { - store, _, err := storageDriver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) - - _, err = store.Initialize(context.Background()) - require.NoError(t, err) - - require.NoError(t, store.UpdateAccountMetadata(context.Background(), "alice", core.Metadata{ - "foo": json.RawMessage(`"bar"`), - })) - require.NoError(t, store.UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ - "alice": { - "USD": core.NewEmptyVolumes().WithInput(big.NewInt(100)), }, - })) - - t.Run("valid address", func(t *testing.T) { - rsp := internal.GetAccount(api, "alice") - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - resp, _ := internal.DecodeSingleResponse[core.AccountWithVolumes](t, rsp.Body) - - require.EqualValues(t, core.AccountWithVolumes{ - Account: core.Account{ - Address: "alice", - Metadata: core.Metadata{ - "foo": "bar", - }, - }, - Volumes: core.AssetsVolumes{ - "USD": { - Input: big.NewInt(100), - Output: big.NewInt(0), + expectQuery: *storage.NewAccountsQuery(). + WithBalanceOperatorFilter("gte"). + WithBalanceFilter("50"). + WithMetadataFilter(map[string]string{}), + }, + { + name: "using balance with specified operator", + queryParams: url.Values{ + "balance": []string{"50"}, + "balanceOperator": []string{"gt"}, + }, + expectQuery: *storage.NewAccountsQuery(). + WithBalanceOperatorFilter("gt"). + WithBalanceFilter("50"). + WithMetadataFilter(map[string]string{}), + }, + { + name: "using invalid balance", + queryParams: url.Values{ + "balance": []string{"xxx"}, + }, + expectedErrorCode: apierrors.ErrValidation, + expectStatusCode: http.StatusBadRequest, + }, + { + name: "using balance with invalid operator", + queryParams: url.Values{ + "balance": []string{"50"}, + "balanceOperator": []string{"xxx"}, + }, + expectedErrorCode: apierrors.ErrValidation, + expectStatusCode: http.StatusBadRequest, + }, + { + name: "using address", + queryParams: url.Values{ + "address": []string{"foo"}, + }, + expectQuery: *storage.NewAccountsQuery(). + WithBalanceOperatorFilter("gte"). + WithAddressFilter("foo"). + WithMetadataFilter(map[string]string{}), + }, + { + name: "using empty cursor", + queryParams: url.Values{ + "cursor": []string{ledgerstore.AccountsPaginationToken{}.Encode()}, + }, + expectQuery: *storage.NewAccountsQuery(). + WithMetadataFilter(nil), + }, + { + name: "using cursor with other param", + queryParams: url.Values{ + "cursor": []string{ledgerstore.AccountsPaginationToken{}.Encode()}, + "after": []string{"foo"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "using invalid cursor", + queryParams: url.Values{ + "cursor": []string{"XXX"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "invalid page size", + queryParams: url.Values{ + "pageSize": []string{"nan"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "page size over maximum", + queryParams: url.Values{ + "pageSize": []string{"1000000"}, + }, + expectQuery: *storage.NewAccountsQuery(). + WithPageSize(controllers.MaxPageSize). + WithMetadataFilter(map[string]string{}). + WithBalanceOperatorFilter("gte"), + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + + if testCase.expectStatusCode == 0 { + testCase.expectStatusCode = http.StatusOK + } + + expectedCursor := sharedapi.Cursor[core.Account]{ + Data: []core.Account{ + { + Address: "world", + Metadata: map[string]any{}, }, }, - }, resp) - }) - - t.Run("unknown address", func(t *testing.T) { - rsp := internal.GetAccount(api, "bob") - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - resp, _ := internal.DecodeSingleResponse[core.AccountWithVolumes](t, rsp.Body) - require.EqualValues(t, core.AccountWithVolumes{ - Account: core.Account{ - Address: "bob", - Metadata: core.Metadata{}, - }, - Volumes: core.AssetsVolumes{}, - }, resp) - }) - - t.Run("invalid address format", func(t *testing.T) { - rsp := internal.GetAccount(api, "accounts::alice") - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid account address format", - }, err) - }) - }) + } + + backend, mockLedger := newTestingBackend(t) + if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { + mockLedger.EXPECT(). + GetAccounts(gomock.Any(), testCase.expectQuery). + Return(expectedCursor, nil) + } + + router := routes.NewRouter(backend, nil, nil) + + req := httptest.NewRequest(http.MethodGet, "/xxx/accounts", nil) + rec := httptest.NewRecorder() + req.URL.RawQuery = testCase.queryParams.Encode() + + router.ServeHTTP(rec, req) + + require.Equal(t, testCase.expectStatusCode, rec.Code) + if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { + cursor := DecodeCursorResponse[core.Account](t, rec.Body) + require.Equal(t, expectedCursor, *cursor) + } else { + err := sharedapi.ErrorResponse{} + Decode(t, rec.Body, &err) + require.EqualValues(t, testCase.expectedErrorCode, err.ErrorCode) + } + }) + } } -func TestPostAccountMetadata(t *testing.T) { - internal.RunTest(t, func(api chi.Router, storageDriver storage.Driver) { - store, _, err := storageDriver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) +func TestGetAccount(t *testing.T) { + t.Parallel() - _, err = store.Initialize(context.Background()) - require.NoError(t, err) + account := core.AccountWithVolumes{ + Account: core.Account{ + Address: "foo", + Metadata: map[string]any{}, + }, + Volumes: map[string]core.Volumes{}, + } - require.NoError(t, store.EnsureAccountExists(context.Background(), "alice")) + backend, mock := newTestingBackend(t) + mock.EXPECT(). + GetAccount(gomock.Any(), "foo"). + Return(&account, nil) - t.Run("valid request", func(t *testing.T) { - rsp := internal.PostAccountMetadata(t, api, "alice", - core.Metadata{ - "foo": json.RawMessage(`"bar"`), - }) - require.Equal(t, http.StatusNoContent, rsp.Result().StatusCode, rsp.Body.String()) - }) + router := routes.NewRouter(backend, nil, nil) - t.Run("unknown account should succeed", func(t *testing.T) { - rsp := internal.PostAccountMetadata(t, api, "bob", - core.Metadata{ - "foo": json.RawMessage(`"bar"`), - }) - require.Equal(t, http.StatusNoContent, rsp.Result().StatusCode, rsp.Body.String()) - }) + req := httptest.NewRequest(http.MethodGet, "/xxx/accounts/foo", nil) + rec := httptest.NewRecorder() - t.Run("invalid address format", func(t *testing.T) { - rsp := internal.PostAccountMetadata(t, api, "accounts::alice", core.Metadata{}) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) + router.ServeHTTP(rec, req) - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid account address format", - }, err) - }) - - t.Run("invalid metadata format", func(t *testing.T) { - rsp := internal.NewRequestOnLedger(t, api, "/accounts/alice/metadata", "invalid") - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) + require.Equal(t, http.StatusOK, rec.Code) + response, _ := DecodeSingleResponse[core.AccountWithVolumes](t, rec.Body) + require.Equal(t, account, response) +} - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid metadata format", - }, err) - }) - }) +func TestPostAccountMetadata(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + queryParams url.Values + expectStatusCode int + expectedErrorCode string + account string + body any + } + + testCases := []testCase{ + { + name: "nominal", + account: "world", + body: core.Metadata{ + "foo": "bar", + }, + }, + { + name: "invalid account address format", + account: "invalid-acc", + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "invalid body", + account: "world", + body: "invalid - not an object", + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + + if testCase.expectStatusCode == 0 { + testCase.expectStatusCode = http.StatusNoContent + } + + backend, mock := newTestingBackend(t) + if testCase.expectStatusCode == http.StatusNoContent { + mock.EXPECT(). + SaveMeta(gomock.Any(), core.MetaTargetTypeAccount, testCase.account, testCase.body). + Return(nil) + } + + router := routes.NewRouter(backend, nil, nil) + + req := httptest.NewRequest(http.MethodPost, "/xxx/accounts/"+testCase.account+"/metadata", Buffer(t, testCase.body)) + rec := httptest.NewRecorder() + req.URL.RawQuery = testCase.queryParams.Encode() + + router.ServeHTTP(rec, req) + + require.Equal(t, testCase.expectStatusCode, rec.Code) + if testCase.expectStatusCode >= 300 || testCase.expectStatusCode < 200 { + err := sharedapi.ErrorResponse{} + Decode(t, rec.Body, &err) + require.EqualValues(t, testCase.expectedErrorCode, err.ErrorCode) + } + }) + } } diff --git a/pkg/api/controllers/api.go b/pkg/api/controllers/api.go new file mode 100644 index 000000000..54055419d --- /dev/null +++ b/pkg/api/controllers/api.go @@ -0,0 +1,63 @@ +package controllers + +import ( + "context" + + "github.com/formancehq/ledger/pkg/core" + "github.com/formancehq/ledger/pkg/ledger" + "github.com/formancehq/ledger/pkg/storage" + "github.com/formancehq/stack/libs/go-libs/api" +) + +//go:generate mockgen -source api.go -destination api_test.go -package controllers_test . Ledger + +type Ledger interface { + GetAccount(ctx context.Context, param string) (*core.AccountWithVolumes, error) + SaveMeta(ctx context.Context, targetType string, targetID any, m core.Metadata) error + GetAccounts(ctx context.Context, query storage.AccountsQuery) (api.Cursor[core.Account], error) + CountAccounts(ctx context.Context, query storage.AccountsQuery) (uint64, error) + GetBalancesAggregated(ctx context.Context, q storage.BalancesQuery) (core.AssetsBalances, error) + GetBalances(ctx context.Context, q storage.BalancesQuery) (api.Cursor[core.AccountsBalances], error) + GetMigrationsInfo(ctx context.Context) ([]core.MigrationInfo, error) + Stats(ctx context.Context) (ledger.Stats, error) + GetLogs(ctx context.Context, query storage.LogsQuery) (api.Cursor[core.Log], error) + CountTransactions(ctx context.Context, query storage.TransactionsQuery) (uint64, error) + GetTransactions(ctx context.Context, query storage.TransactionsQuery) (api.Cursor[core.ExpandedTransaction], error) + CreateTransaction(ctx context.Context, preview bool, data core.RunScript) (*core.ExpandedTransaction, error) + GetTransaction(ctx context.Context, id uint64) (*core.ExpandedTransaction, error) + RevertTransaction(ctx context.Context, id uint64) (*core.ExpandedTransaction, error) +} + +type Backend interface { + GetLedger(ctx context.Context, name string) (Ledger, error) + ListLedgers(ctx context.Context) ([]string, error) + GetVersion() string +} + +type DefaultBackend struct { + storageDriver storage.Driver + resolver *ledger.Resolver + version string +} + +func (d DefaultBackend) GetLedger(ctx context.Context, name string) (Ledger, error) { + return d.resolver.GetLedger(ctx, name) +} + +func (d DefaultBackend) ListLedgers(ctx context.Context) ([]string, error) { + return d.storageDriver.GetSystemStore().ListLedgers(ctx) +} + +func (d DefaultBackend) GetVersion() string { + return d.version +} + +var _ Backend = (*DefaultBackend)(nil) + +func NewDefaultBackend(driver storage.Driver, version string, resolver *ledger.Resolver) *DefaultBackend { + return &DefaultBackend{ + storageDriver: driver, + resolver: resolver, + version: version, + } +} diff --git a/pkg/api/controllers/api_test.go b/pkg/api/controllers/api_test.go new file mode 100644 index 000000000..98b5e53d9 --- /dev/null +++ b/pkg/api/controllers/api_test.go @@ -0,0 +1,316 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: api.go + +// Package controllers_test is a generated GoMock package. +package controllers_test + +import ( + context "context" + reflect "reflect" + + controllers "github.com/formancehq/ledger/pkg/api/controllers" + core "github.com/formancehq/ledger/pkg/core" + ledger "github.com/formancehq/ledger/pkg/ledger" + storage "github.com/formancehq/ledger/pkg/storage" + api "github.com/formancehq/stack/libs/go-libs/api" + gomock "github.com/golang/mock/gomock" +) + +// MockLedger is a mock of Ledger interface. +type MockLedger struct { + ctrl *gomock.Controller + recorder *MockLedgerMockRecorder +} + +// MockLedgerMockRecorder is the mock recorder for MockLedger. +type MockLedgerMockRecorder struct { + mock *MockLedger +} + +// NewMockLedger creates a new mock instance. +func NewMockLedger(ctrl *gomock.Controller) *MockLedger { + mock := &MockLedger{ctrl: ctrl} + mock.recorder = &MockLedgerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLedger) EXPECT() *MockLedgerMockRecorder { + return m.recorder +} + +// CountAccounts mocks base method. +func (m *MockLedger) CountAccounts(ctx context.Context, query storage.AccountsQuery) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAccounts", ctx, query) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAccounts indicates an expected call of CountAccounts. +func (mr *MockLedgerMockRecorder) CountAccounts(ctx, query interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccounts", reflect.TypeOf((*MockLedger)(nil).CountAccounts), ctx, query) +} + +// CountTransactions mocks base method. +func (m *MockLedger) CountTransactions(ctx context.Context, query storage.TransactionsQuery) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountTransactions", ctx, query) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountTransactions indicates an expected call of CountTransactions. +func (mr *MockLedgerMockRecorder) CountTransactions(ctx, query interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountTransactions", reflect.TypeOf((*MockLedger)(nil).CountTransactions), ctx, query) +} + +// CreateTransaction mocks base method. +func (m *MockLedger) CreateTransaction(ctx context.Context, preview bool, data core.RunScript) (*core.ExpandedTransaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateTransaction", ctx, preview, data) + ret0, _ := ret[0].(*core.ExpandedTransaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateTransaction indicates an expected call of CreateTransaction. +func (mr *MockLedgerMockRecorder) CreateTransaction(ctx, preview, data interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTransaction", reflect.TypeOf((*MockLedger)(nil).CreateTransaction), ctx, preview, data) +} + +// GetAccount mocks base method. +func (m *MockLedger) GetAccount(ctx context.Context, param string) (*core.AccountWithVolumes, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccount", ctx, param) + ret0, _ := ret[0].(*core.AccountWithVolumes) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccount indicates an expected call of GetAccount. +func (mr *MockLedgerMockRecorder) GetAccount(ctx, param interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccount", reflect.TypeOf((*MockLedger)(nil).GetAccount), ctx, param) +} + +// GetAccounts mocks base method. +func (m *MockLedger) GetAccounts(ctx context.Context, query storage.AccountsQuery) (api.Cursor[core.Account], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccounts", ctx, query) + ret0, _ := ret[0].(api.Cursor[core.Account]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccounts indicates an expected call of GetAccounts. +func (mr *MockLedgerMockRecorder) GetAccounts(ctx, query interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccounts", reflect.TypeOf((*MockLedger)(nil).GetAccounts), ctx, query) +} + +// GetBalances mocks base method. +func (m *MockLedger) GetBalances(ctx context.Context, q storage.BalancesQuery) (api.Cursor[core.AccountsBalances], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBalances", ctx, q) + ret0, _ := ret[0].(api.Cursor[core.AccountsBalances]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetBalances indicates an expected call of GetBalances. +func (mr *MockLedgerMockRecorder) GetBalances(ctx, q interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalances", reflect.TypeOf((*MockLedger)(nil).GetBalances), ctx, q) +} + +// GetBalancesAggregated mocks base method. +func (m *MockLedger) GetBalancesAggregated(ctx context.Context, q storage.BalancesQuery) (core.AssetsBalances, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBalancesAggregated", ctx, q) + ret0, _ := ret[0].(core.AssetsBalances) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetBalancesAggregated indicates an expected call of GetBalancesAggregated. +func (mr *MockLedgerMockRecorder) GetBalancesAggregated(ctx, q interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalancesAggregated", reflect.TypeOf((*MockLedger)(nil).GetBalancesAggregated), ctx, q) +} + +// GetLogs mocks base method. +func (m *MockLedger) GetLogs(ctx context.Context, query storage.LogsQuery) (api.Cursor[core.Log], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLogs", ctx, query) + ret0, _ := ret[0].(api.Cursor[core.Log]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLogs indicates an expected call of GetLogs. +func (mr *MockLedgerMockRecorder) GetLogs(ctx, query interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogs", reflect.TypeOf((*MockLedger)(nil).GetLogs), ctx, query) +} + +// GetMigrationsInfo mocks base method. +func (m *MockLedger) GetMigrationsInfo(ctx context.Context) ([]core.MigrationInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMigrationsInfo", ctx) + ret0, _ := ret[0].([]core.MigrationInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMigrationsInfo indicates an expected call of GetMigrationsInfo. +func (mr *MockLedgerMockRecorder) GetMigrationsInfo(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMigrationsInfo", reflect.TypeOf((*MockLedger)(nil).GetMigrationsInfo), ctx) +} + +// GetTransaction mocks base method. +func (m *MockLedger) GetTransaction(ctx context.Context, id uint64) (*core.ExpandedTransaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTransaction", ctx, id) + ret0, _ := ret[0].(*core.ExpandedTransaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTransaction indicates an expected call of GetTransaction. +func (mr *MockLedgerMockRecorder) GetTransaction(ctx, id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTransaction", reflect.TypeOf((*MockLedger)(nil).GetTransaction), ctx, id) +} + +// GetTransactions mocks base method. +func (m *MockLedger) GetTransactions(ctx context.Context, query storage.TransactionsQuery) (api.Cursor[core.ExpandedTransaction], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTransactions", ctx, query) + ret0, _ := ret[0].(api.Cursor[core.ExpandedTransaction]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTransactions indicates an expected call of GetTransactions. +func (mr *MockLedgerMockRecorder) GetTransactions(ctx, query interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTransactions", reflect.TypeOf((*MockLedger)(nil).GetTransactions), ctx, query) +} + +// RevertTransaction mocks base method. +func (m *MockLedger) RevertTransaction(ctx context.Context, id uint64) (*core.ExpandedTransaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevertTransaction", ctx, id) + ret0, _ := ret[0].(*core.ExpandedTransaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RevertTransaction indicates an expected call of RevertTransaction. +func (mr *MockLedgerMockRecorder) RevertTransaction(ctx, id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevertTransaction", reflect.TypeOf((*MockLedger)(nil).RevertTransaction), ctx, id) +} + +// SaveMeta mocks base method. +func (m_2 *MockLedger) SaveMeta(ctx context.Context, targetType string, targetID any, m core.Metadata) error { + m_2.ctrl.T.Helper() + ret := m_2.ctrl.Call(m_2, "SaveMeta", ctx, targetType, targetID, m) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveMeta indicates an expected call of SaveMeta. +func (mr *MockLedgerMockRecorder) SaveMeta(ctx, targetType, targetID, m interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveMeta", reflect.TypeOf((*MockLedger)(nil).SaveMeta), ctx, targetType, targetID, m) +} + +// Stats mocks base method. +func (m *MockLedger) Stats(ctx context.Context) (ledger.Stats, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stats", ctx) + ret0, _ := ret[0].(ledger.Stats) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Stats indicates an expected call of Stats. +func (mr *MockLedgerMockRecorder) Stats(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stats", reflect.TypeOf((*MockLedger)(nil).Stats), ctx) +} + +// MockBackend is a mock of Backend interface. +type MockBackend struct { + ctrl *gomock.Controller + recorder *MockBackendMockRecorder +} + +// MockBackendMockRecorder is the mock recorder for MockBackend. +type MockBackendMockRecorder struct { + mock *MockBackend +} + +// NewMockBackend creates a new mock instance. +func NewMockBackend(ctrl *gomock.Controller) *MockBackend { + mock := &MockBackend{ctrl: ctrl} + mock.recorder = &MockBackendMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBackend) EXPECT() *MockBackendMockRecorder { + return m.recorder +} + +// GetLedger mocks base method. +func (m *MockBackend) GetLedger(ctx context.Context, name string) (controllers.Ledger, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLedger", ctx, name) + ret0, _ := ret[0].(controllers.Ledger) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLedger indicates an expected call of GetLedger. +func (mr *MockBackendMockRecorder) GetLedger(ctx, name interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLedger", reflect.TypeOf((*MockBackend)(nil).GetLedger), ctx, name) +} + +// GetVersion mocks base method. +func (m *MockBackend) GetVersion() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetVersion") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetVersion indicates an expected call of GetVersion. +func (mr *MockBackendMockRecorder) GetVersion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockBackend)(nil).GetVersion)) +} + +// ListLedgers mocks base method. +func (m *MockBackend) ListLedgers(ctx context.Context) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListLedgers", ctx) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListLedgers indicates an expected call of ListLedgers. +func (mr *MockBackendMockRecorder) ListLedgers(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListLedgers", reflect.TypeOf((*MockBackend)(nil).ListLedgers), ctx) +} diff --git a/pkg/api/controllers/balance_controller_test.go b/pkg/api/controllers/balance_controller_test.go index edf6b29e8..d0f58d8ea 100644 --- a/pkg/api/controllers/balance_controller_test.go +++ b/pkg/api/controllers/balance_controller_test.go @@ -1,164 +1,162 @@ package controllers_test import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" "math/big" "net/http" + "net/http/httptest" "net/url" "testing" - "github.com/formancehq/ledger/pkg/api/controllers" - "github.com/formancehq/ledger/pkg/api/internal" + "github.com/formancehq/ledger/pkg/api/apierrors" + "github.com/formancehq/ledger/pkg/api/routes" "github.com/formancehq/ledger/pkg/core" "github.com/formancehq/ledger/pkg/storage" - ledgerstore "github.com/formancehq/ledger/pkg/storage/sqlstorage/ledger" - "github.com/go-chi/chi/v5" + "github.com/formancehq/ledger/pkg/storage/sqlstorage/ledger" + sharedapi "github.com/formancehq/stack/libs/go-libs/api" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" ) func TestGetBalancesAggregated(t *testing.T) { - internal.RunTest(t, func(api chi.Router, storageDriver storage.Driver) { - store, _, err := storageDriver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) - - _, err = store.Initialize(context.Background()) - require.NoError(t, err) - - require.NoError(t, store.UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ - "world": { - "USD": core.NewEmptyVolumes().WithOutput(big.NewInt(250)), - }, - "alice": { - "USD": core.NewEmptyVolumes().WithInput(big.NewInt(150)), - }, - "bob": { - "USD": core.NewEmptyVolumes().WithInput(big.NewInt(100)), + t.Parallel() + + type testCase struct { + name string + queryParams url.Values + expectQuery storage.BalancesQuery + } + + testCases := []testCase{ + { + name: "nominal", + expectQuery: *storage.NewBalancesQuery(), + }, + { + name: "using address", + queryParams: url.Values{ + "address": []string{"foo"}, }, - })) - - t.Run("all", func(t *testing.T) { - rsp := internal.GetBalancesAggregated(api, url.Values{}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - - resp, ok := internal.DecodeSingleResponse[core.AssetsBalances](t, rsp.Body) - require.Equal(t, ok, true) - require.Equal(t, core.AssetsBalances{"USD": big.NewInt(0)}, resp) - }) - - t.Run("filter by address", func(t *testing.T) { - rsp := internal.GetBalancesAggregated(api, url.Values{"address": []string{"world"}}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - - resp, ok := internal.DecodeSingleResponse[core.AssetsBalances](t, rsp.Body) - require.Equal(t, true, ok) - require.Equal(t, core.AssetsBalances{"USD": big.NewInt(-250)}, resp) - }) - - t.Run("filter by address no result", func(t *testing.T) { - rsp := internal.GetBalancesAggregated(api, url.Values{"address": []string{"XXX"}}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - - resp, ok := internal.DecodeSingleResponse[core.AssetsBalances](t, rsp.Body) - require.Equal(t, ok, true) - require.Equal(t, core.AssetsBalances{}, resp) + expectQuery: *storage.NewBalancesQuery().WithAddressFilter("foo"), + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + + expectedBalances := core.AssetsBalances{ + "world": big.NewInt(-100), + } + backend, mock := newTestingBackend(t) + mock.EXPECT(). + GetBalancesAggregated(gomock.Any(), testCase.expectQuery). + Return(expectedBalances, nil) + + router := routes.NewRouter(backend, nil, nil) + + req := httptest.NewRequest(http.MethodGet, "/xxx/aggregate/balances", nil) + rec := httptest.NewRecorder() + req.URL.RawQuery = testCase.queryParams.Encode() + + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + balances, ok := DecodeSingleResponse[core.AssetsBalances](t, rec.Body) + require.True(t, ok) + require.Equal(t, expectedBalances, balances) }) - }) + } } func TestGetBalances(t *testing.T) { - internal.RunTest(t, func(api chi.Router, storageDriver storage.Driver) { - store, _, err := storageDriver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) - - _, err = store.Initialize(context.Background()) - require.NoError(t, err) - - require.NoError(t, store.UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ - "world": { - "USD": core.NewEmptyVolumes().WithOutput(big.NewInt(250)), - "CAD": core.NewEmptyVolumes().WithOutput(big.NewInt(200)), - "EUR": core.NewEmptyVolumes().WithOutput(big.NewInt(400)), + t.Parallel() + + type testCase struct { + name string + queryParams url.Values + expectQuery storage.BalancesQuery + expectStatusCode int + expectedErrorCode string + } + + testCases := []testCase{ + { + name: "nominal", + expectQuery: *storage.NewBalancesQuery(), + }, + { + name: "empty cursor with other param", + queryParams: url.Values{ + "cursor": []string{ledger.BalancesPaginationToken{}.Encode()}, + "after": []string{"bob"}, }, - "alice": { - "USD": core.NewEmptyVolumes().WithInput(big.NewInt(150)), - "CAD": core.NewEmptyVolumes().WithInput(big.NewInt(200)), - "EUR": core.NewEmptyVolumes().WithInput(big.NewInt(400)), + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "invalid cursor", + queryParams: url.Values{ + "cursor": []string{"xxx"}, }, - "bob": { - "USD": core.NewEmptyVolumes().WithInput(big.NewInt(100)), + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "using after", + queryParams: url.Values{ + "after": []string{"foo"}, }, - })) - - to := ledgerstore.BalancesPaginationToken{} - raw, err := json.Marshal(to) - require.NoError(t, err) - - t.Run("valid empty "+controllers.QueryKeyCursor, func(t *testing.T) { - rsp := internal.GetBalances(api, url.Values{ - controllers.QueryKeyCursor: []string{base64.RawURLEncoding.EncodeToString(raw)}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode, rsp.Body.String()) - }) - - t.Run(fmt.Sprintf("valid empty %s with any other param is forbidden", controllers.QueryKeyCursor), func(t *testing.T) { - rsp := internal.GetBalances(api, url.Values{ - controllers.QueryKeyCursor: []string{base64.RawURLEncoding.EncodeToString(raw)}, - "after": []string{"bob"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - }) - - t.Run(fmt.Sprintf("invalid %s", controllers.QueryKeyCursor), func(t *testing.T) { - rsp := internal.GetBalances(api, url.Values{ - controllers.QueryKeyCursor: []string{"invalid"}, - }) - - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - require.Contains(t, rsp.Body.String(), - fmt.Sprintf(`"invalid '%s' query param"`, controllers.QueryKeyCursor)) - }) - - t.Run("all", func(t *testing.T) { - rsp := internal.GetBalances(api, url.Values{}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - - resp := internal.DecodeCursorResponse[core.AccountsBalances](t, rsp.Body) - require.Equal(t, []core.AccountsBalances{ - {"world": core.AssetsBalances{"USD": big.NewInt(-250), "EUR": big.NewInt(-400), "CAD": big.NewInt(-200)}}, - {"bob": core.AssetsBalances{"USD": big.NewInt(100)}}, - {"alice": core.AssetsBalances{"USD": big.NewInt(150), "EUR": big.NewInt(400), "CAD": big.NewInt(200)}}, - }, resp.Data) - }) - - t.Run("after address", func(t *testing.T) { - rsp := internal.GetBalances(api, url.Values{"after": []string{"bob"}}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - - resp := internal.DecodeCursorResponse[core.AccountsBalances](t, rsp.Body) - require.Equal(t, []core.AccountsBalances{ - {"alice": core.AssetsBalances{"USD": big.NewInt(150), "EUR": big.NewInt(400), "CAD": big.NewInt(200)}}, - }, resp.Data) - }) - - t.Run("filter by address", func(t *testing.T) { - rsp := internal.GetBalances(api, url.Values{"address": []string{"world"}}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - - resp := internal.DecodeCursorResponse[core.AccountsBalances](t, rsp.Body) - require.Equal(t, []core.AccountsBalances{ - {"world": core.AssetsBalances{"USD": big.NewInt(-250), "EUR": big.NewInt(-400), "CAD": big.NewInt(-200)}}, - }, resp.Data) - }) - - t.Run("filter by address no results", func(t *testing.T) { - rsp := internal.GetBalances(api, url.Values{"address": []string{"TEST"}}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - - resp := internal.DecodeCursorResponse[core.AccountsBalances](t, rsp.Body) - require.Equal(t, []core.AccountsBalances{}, resp.Data) + expectQuery: *storage.NewBalancesQuery().WithAfterAddress("foo"), + }, + { + name: "using address", + queryParams: url.Values{ + "address": []string{"foo"}, + }, + expectQuery: *storage.NewBalancesQuery().WithAddressFilter("foo"), + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + + if testCase.expectStatusCode == 0 { + testCase.expectStatusCode = http.StatusOK + } + + expectedCursor := sharedapi.Cursor[core.AccountsBalances]{ + Data: []core.AccountsBalances{ + { + "world": core.AssetsBalances{ + "USD": big.NewInt(100), + }, + }, + }, + } + + backend, mock := newTestingBackend(t) + if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { + mock.EXPECT(). + GetBalances(gomock.Any(), testCase.expectQuery). + Return(expectedCursor, nil) + } + + router := routes.NewRouter(backend, nil, nil) + + req := httptest.NewRequest(http.MethodGet, "/xxx/balances", nil) + rec := httptest.NewRecorder() + req.URL.RawQuery = testCase.queryParams.Encode() + + router.ServeHTTP(rec, req) + + require.Equal(t, testCase.expectStatusCode, rec.Code) + if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { + cursor := DecodeCursorResponse[core.AccountsBalances](t, rec.Body) + require.Equal(t, expectedCursor, *cursor) + } else { + err := sharedapi.ErrorResponse{} + Decode(t, rec.Body, &err) + require.EqualValues(t, testCase.expectedErrorCode, err.ErrorCode) + } }) - }) + } } diff --git a/pkg/api/controllers/config_controller.go b/pkg/api/controllers/config_controller.go index e2b89f97b..1ba68f7fa 100644 --- a/pkg/api/controllers/config_controller.go +++ b/pkg/api/controllers/config_controller.go @@ -4,14 +4,13 @@ import ( _ "embed" "net/http" - "github.com/formancehq/ledger/pkg/storage" sharedapi "github.com/formancehq/stack/libs/go-libs/api" ) type ConfigInfo struct { - Server string `json:"server"` - Version interface{} `json:"version"` - Config *Config `json:"config"` + Server string `json:"server"` + Version string `json:"version"` + Config *Config `json:"config"` } type Config struct { @@ -23,19 +22,19 @@ type LedgerStorage struct { Ledgers []string `json:"ledgers"` } -func GetInfo(storageDriver storage.Driver, version string) func(w http.ResponseWriter, r *http.Request) { +func GetInfo(backend Backend) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - ledgers, err := storageDriver.GetSystemStore().ListLedgers(r.Context()) + ledgers, err := backend.ListLedgers(r.Context()) if err != nil { panic(err) } sharedapi.RawOk(w, ConfigInfo{ Server: "ledger", - Version: version, + Version: backend.GetVersion(), Config: &Config{ LedgerStorage: &LedgerStorage{ - Driver: storageDriver.Name(), + Driver: "postgres", Ledgers: ledgers, }, }, diff --git a/pkg/api/controllers/config_controller_test.go b/pkg/api/controllers/config_controller_test.go index 3a6643406..e1e1fcb7a 100644 --- a/pkg/api/controllers/config_controller_test.go +++ b/pkg/api/controllers/config_controller_test.go @@ -3,33 +3,49 @@ package controllers_test import ( "encoding/json" "net/http" + "net/http/httptest" "testing" "github.com/formancehq/ledger/pkg/api/controllers" - "github.com/formancehq/ledger/pkg/api/internal" - "github.com/formancehq/ledger/pkg/storage" - "github.com/go-chi/chi/v5" + "github.com/formancehq/ledger/pkg/api/routes" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" ) func TestGetInfo(t *testing.T) { - internal.RunTest(t, func(h chi.Router, driver storage.Driver) { - rsp := internal.GetInfo(h) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - - info := controllers.ConfigInfo{} - require.NoError(t, json.NewDecoder(rsp.Body).Decode(&info)) - - info.Config.LedgerStorage.Ledgers = []string{} - require.EqualValues(t, controllers.ConfigInfo{ - Server: "ledger", - Version: "latest", - Config: &controllers.Config{ - LedgerStorage: &controllers.LedgerStorage{ - Driver: driver.Name(), - Ledgers: []string{}, - }, + t.Parallel() + + backend, _ := newTestingBackend(t) + router := routes.NewRouter(backend, nil, nil) + + backend. + EXPECT(). + ListLedgers(gomock.Any()). + Return([]string{"a", "b"}, nil) + + backend. + EXPECT(). + GetVersion(). + Return("latest") + + req := httptest.NewRequest(http.MethodGet, "/_info", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + info := controllers.ConfigInfo{} + require.NoError(t, json.NewDecoder(rec.Body).Decode(&info)) + + require.EqualValues(t, controllers.ConfigInfo{ + Server: "ledger", + Version: "latest", + Config: &controllers.Config{ + LedgerStorage: &controllers.LedgerStorage{ + Driver: "postgres", + Ledgers: []string{"a", "b"}, }, - }, info) - }) + }, + }, info) } diff --git a/pkg/api/controllers/context.go b/pkg/api/controllers/context.go index 0bf3c01ef..bdf8d04c0 100644 --- a/pkg/api/controllers/context.go +++ b/pkg/api/controllers/context.go @@ -2,18 +2,16 @@ package controllers import ( "context" - - "github.com/formancehq/ledger/pkg/ledger" ) type ledgerKey struct{} var _ledgerKey = ledgerKey{} -func ContextWithLedger(ctx context.Context, ledger *ledger.Ledger) context.Context { +func ContextWithLedger(ctx context.Context, ledger Ledger) context.Context { return context.WithValue(ctx, _ledgerKey, ledger) } -func LedgerFromContext(ctx context.Context) *ledger.Ledger { - return ctx.Value(_ledgerKey).(*ledger.Ledger) +func LedgerFromContext(ctx context.Context) Ledger { + return ctx.Value(_ledgerKey).(Ledger) } diff --git a/pkg/api/controllers/ledger_controller.go b/pkg/api/controllers/ledger_controller.go index 4ced1df75..3bcb5bd7f 100644 --- a/pkg/api/controllers/ledger_controller.go +++ b/pkg/api/controllers/ledger_controller.go @@ -19,10 +19,10 @@ import ( type Info struct { Name string `json:"name"` - Storage storageInfo `json:"storage"` + Storage StorageInfo `json:"storage"` } -type storageInfo struct { +type StorageInfo struct { Migrations []core.MigrationInfo `json:"migrations"` } @@ -32,7 +32,7 @@ func GetLedgerInfo(w http.ResponseWriter, r *http.Request) { var err error res := Info{ Name: chi.URLParam(r, "ledger"), - Storage: storageInfo{}, + Storage: StorageInfo{}, } res.Storage.Migrations, err = ledger.GetMigrationsInfo(r.Context()) if err != nil { @@ -132,7 +132,7 @@ func GetLogs(w http.ResponseWriter, r *http.Request) { WithPageSize(pageSize) } - cursor, err := l.GetLogs(r.Context(), logsQuery) + cursor, err := l.GetLogs(r.Context(), *logsQuery) if err != nil { apierrors.ResponseError(w, r, err) return diff --git a/pkg/api/controllers/ledger_controller_test.go b/pkg/api/controllers/ledger_controller_test.go index 499fc5129..06a51052d 100644 --- a/pkg/api/controllers/ledger_controller_test.go +++ b/pkg/api/controllers/ledger_controller_test.go @@ -1,303 +1,225 @@ package controllers_test import ( - "context" - "encoding/base64" "encoding/json" - "fmt" - "math/big" "net/http" + "net/http/httptest" "net/url" "testing" "time" "github.com/formancehq/ledger/pkg/api/apierrors" "github.com/formancehq/ledger/pkg/api/controllers" - "github.com/formancehq/ledger/pkg/api/internal" + "github.com/formancehq/ledger/pkg/api/routes" "github.com/formancehq/ledger/pkg/core" "github.com/formancehq/ledger/pkg/ledger" "github.com/formancehq/ledger/pkg/storage" ledgerstore "github.com/formancehq/ledger/pkg/storage/sqlstorage/ledger" - "github.com/formancehq/ledger/pkg/storage/sqlstorage/migrations" sharedapi "github.com/formancehq/stack/libs/go-libs/api" - "github.com/go-chi/chi/v5" - "github.com/google/uuid" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" ) func TestGetLedgerInfo(t *testing.T) { - internal.RunTest(t, func(h chi.Router, driver storage.Driver) { - availableMigrations, err := migrations.CollectMigrationFiles(ledgerstore.MigrationsFS) - require.NoError(t, err) - - rsp := internal.GetLedgerInfo(h) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - info, ok := internal.DecodeSingleResponse[controllers.Info](t, rsp.Body) - require.Equal(t, true, ok) - - _, err = uuid.Parse(info.Name) - require.NoError(t, err) - - require.Equal(t, len(availableMigrations), len(info.Storage.Migrations)) - - for _, m := range info.Storage.Migrations { - require.Equal(t, "DONE", m.State) - require.NotEqual(t, "", m.Name) - require.NotEqual(t, time.Time{}, m.Date) - } - }) + t.Parallel() + + backend, mock := newTestingBackend(t) + router := routes.NewRouter(backend, nil, nil) + + migrationInfo := []core.MigrationInfo{ + { + Version: "1", + Name: "init", + State: "ready", + Date: core.Now().Add(-2 * time.Minute).Round(time.Second), + }, + { + Version: "2", + Name: "fix", + State: "ready", + Date: core.Now().Add(-time.Minute).Round(time.Second), + }, + } + + mock.EXPECT(). + GetMigrationsInfo(gomock.Any()). + Return(migrationInfo, nil) + + req := httptest.NewRequest(http.MethodGet, "/xxx/_info", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + info, ok := DecodeSingleResponse[controllers.Info](t, rec.Body) + require.True(t, ok) + + require.EqualValues(t, controllers.Info{ + Name: "xxx", + Storage: controllers.StorageInfo{ + Migrations: migrationInfo, + }, + }, info) } func TestGetStats(t *testing.T) { - internal.RunTest(t, func(h chi.Router, storageDriver storage.Driver) { - store, _, err := storageDriver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) - - _, err = store.Initialize(context.Background()) - require.NoError(t, err) - - require.NoError(t, store.InsertTransactions(context.Background(), core.ExpandTransactionFromEmptyPreCommitVolumes( - core.NewTransaction().WithPostings( - core.NewPosting("world", "alice", "USD", big.NewInt(100)), - ), - ))) - require.NoError(t, store.InsertTransactions(context.Background(), core.ExpandTransactionFromEmptyPreCommitVolumes( - core.NewTransaction(). - WithPostings(core.NewPosting("world", "bob", "USD", big.NewInt(100))). - WithID(1), - ))) - require.NoError(t, store.EnsureAccountExists(context.Background(), "world")) - require.NoError(t, store.EnsureAccountExists(context.Background(), "alice")) - require.NoError(t, store.EnsureAccountExists(context.Background(), "bob")) - - rsp := internal.GetLedgerStats(h) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - - stats, _ := internal.DecodeSingleResponse[ledger.Stats](t, rsp.Body) - - require.EqualValues(t, ledger.Stats{ - Transactions: 2, - Accounts: 3, - }, stats) - }) -} + t.Parallel() -func TestGetLogs(t *testing.T) { - internal.RunTest(t, func(api chi.Router, driver storage.Driver) { - now := core.Now() - tx1 := core.ExpandedTransaction{ - Transaction: core.Transaction{ - ID: 0, - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "alice", - Amount: big.NewInt(100), - Asset: "USD", - }, - }, - Timestamp: now.Add(-3 * time.Hour), - }, - }, - } - tx2 := core.ExpandedTransaction{ - Transaction: core.Transaction{ - ID: 1, - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "bob", - Amount: big.NewInt(200), - Asset: "USD", - }, - }, - Timestamp: now.Add(-2 * time.Hour), - }, - }, - } - store := internal.GetLedgerStore(t, driver, context.Background()) - _, err := store.Initialize(context.Background()) - require.NoError(t, err) - - require.NoError(t, store.InsertTransactions(context.Background(), tx1, tx2)) - - for _, tx := range []core.ExpandedTransaction{tx1, tx2} { - log := core.NewTransactionLog(tx.Transaction, nil) - require.NoError(t, store.AppendLog(context.Background(), &log)) - } - - at := core.Now() - require.NoError(t, store.UpdateTransactionMetadata(context.Background(), - 0, core.Metadata{"key": "value"})) - - log := core.NewSetMetadataLog(at, core.SetMetadataLogPayload{ - TargetType: core.MetaTargetTypeTransaction, - TargetID: 0, - Metadata: core.Metadata{"key": "value"}, - }) - require.NoError(t, store.AppendLog(context.Background(), &log)) + backend, mock := newTestingBackend(t) + router := routes.NewRouter(backend, nil, nil) - at2 := core.Now() - require.NoError(t, store.UpdateAccountMetadata(context.Background(), "alice", core.Metadata{"key": "value"})) + expectedStats := ledger.Stats{ + Transactions: 10, + Accounts: 5, + } - log2 := core.NewSetMetadataLog(at2, core.SetMetadataLogPayload{ - TargetType: core.MetaTargetTypeAccount, - TargetID: "alice", - Metadata: core.Metadata{"key": "value"}, - }) - require.NoError(t, store.AppendLog(context.Background(), &log2)) - - var log0Timestamp, log1Timestamp core.Time - t.Run("all", func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Log](t, rsp.Body) - // all logs - require.Len(t, cursor.Data, 4) - require.Equal(t, uint64(3), cursor.Data[0].ID) - require.Equal(t, uint64(2), cursor.Data[1].ID) - require.Equal(t, uint64(1), cursor.Data[2].ID) - require.Equal(t, uint64(0), cursor.Data[3].ID) - - log0Timestamp = cursor.Data[3].Date - log1Timestamp = cursor.Data[2].Date - }) + mock.EXPECT(). + Stats(gomock.Any()). + Return(expectedStats, nil) - t.Run("after", func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{ - "after": []string{"1"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Log](t, rsp.Body) - require.Len(t, cursor.Data, 1) - require.Equal(t, uint64(0), cursor.Data[0].ID) - }) + req := httptest.NewRequest(http.MethodGet, "/xxx/stats", nil) + rec := httptest.NewRecorder() - t.Run("invalid after", func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{ - "after": []string{"invalid"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid 'after' query param", - }, err) - }) + router.ServeHTTP(rec, req) - t.Run("time range", func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyStartTime: []string{log0Timestamp.Format(time.RFC3339)}, - controllers.QueryKeyEndTime: []string{log1Timestamp.Format(time.RFC3339)}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Log](t, rsp.Body) - require.Len(t, cursor.Data, 1) - require.Equal(t, uint64(0), cursor.Data[0].ID) - }) + require.Equal(t, http.StatusOK, rec.Code) - t.Run("only start time", func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyStartTime: []string{core.Now().Add(time.Second).Format(time.RFC3339)}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Log](t, rsp.Body) - require.Len(t, cursor.Data, 0) - }) + stats, ok := DecodeSingleResponse[ledger.Stats](t, rec.Body) + require.True(t, ok) - t.Run("only end time", func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyEndTime: []string{core.Now().Add(time.Second).Format(time.RFC3339)}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.Log](t, rsp.Body) - require.Len(t, cursor.Data, 4) - }) + require.EqualValues(t, expectedStats, stats) +} - t.Run("invalid start time", func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyStartTime: []string{"invalid time"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: controllers.ErrInvalidStartTime.Error(), - }, err) - }) +func TestGetLogs(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + queryParams url.Values + expectQuery storage.LogsQuery + expectStatusCode int + expectedErrorCode string + } + + now := core.Now() + testCases := []testCase{ + { + name: "nominal", + expectQuery: *storage.NewLogsQuery(), + }, + { + name: "using after", + queryParams: url.Values{ + "after": []string{"10"}, + }, + expectQuery: *storage.NewLogsQuery().WithAfterID(10), + }, + { + name: "using invalid after", + queryParams: url.Values{ + "after": []string{"xxx"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "using start time", + queryParams: url.Values{ + "startTime": []string{now.Format(core.DateFormat)}, + }, + expectQuery: *storage.NewLogsQuery().WithStartTimeFilter(now), + }, + { + name: "using end time", + queryParams: url.Values{ + "endTime": []string{now.Format(core.DateFormat)}, + }, + expectQuery: *storage.NewLogsQuery().WithEndTimeFilter(now), + }, + { + name: "using invalid start time", + queryParams: url.Values{ + "startTime": []string{"xxx"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "using invalid end time", + queryParams: url.Values{ + "endTime": []string{"xxx"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "using empty cursor", + queryParams: url.Values{ + "cursor": []string{ledgerstore.LogsPaginationToken{}.Encode()}, + }, + expectQuery: *storage.NewLogsQuery(), + }, + { + name: "using invalid cursor", + queryParams: url.Values{ + "cursor": []string{"xxx"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + + if testCase.expectStatusCode == 0 { + testCase.expectStatusCode = http.StatusOK + } + + expectedCursor := sharedapi.Cursor[core.Log]{ + Data: []core.Log{ + core.NewTransactionLog(core.Transaction{}, map[string]core.Metadata{}), + }, + } - t.Run("invalid end time", func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyEndTime: []string{"invalid time"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: controllers.ErrInvalidEndTime.Error(), - }, err) - }) + backend, mockLedger := newTestingBackend(t) + if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { + mockLedger.EXPECT(). + GetLogs(gomock.Any(), testCase.expectQuery). + Return(expectedCursor, nil) + } - to := ledgerstore.LogsPaginationToken{} - raw, err := json.Marshal(to) - require.NoError(t, err) + router := routes.NewRouter(backend, nil, nil) - t.Run(fmt.Sprintf("valid empty %s", controllers.QueryKeyCursor), func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyCursor: []string{base64.RawURLEncoding.EncodeToString(raw)}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode, rsp.Body.String()) - }) + req := httptest.NewRequest(http.MethodGet, "/xxx/logs", nil) + rec := httptest.NewRecorder() + req.URL.RawQuery = testCase.queryParams.Encode() - t.Run(fmt.Sprintf("valid empty %s with any other param is forbidden", controllers.QueryKeyCursor), func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyCursor: []string{base64.RawURLEncoding.EncodeToString(raw)}, - "after": []string{"1"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: fmt.Sprintf("no other query params can be set with '%s'", controllers.QueryKeyCursor), - }, err) - }) + router.ServeHTTP(rec, req) - t.Run(fmt.Sprintf("invalid %s", controllers.QueryKeyCursor), func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyCursor: []string{"invalid"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: fmt.Sprintf("invalid '%s' query param", controllers.QueryKeyCursor), - }, err) - }) + require.Equal(t, testCase.expectStatusCode, rec.Code) + if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { + cursor := DecodeCursorResponse[core.Log](t, rec.Body) + + cursorData, err := json.Marshal(cursor) + require.NoError(t, err) + + cursorAsMap := make(map[string]any) + require.NoError(t, json.Unmarshal(cursorData, &cursorAsMap)) + + expectedCursorData, err := json.Marshal(expectedCursor) + require.NoError(t, err) + + expectedCursorAsMap := make(map[string]any) + require.NoError(t, json.Unmarshal(expectedCursorData, &expectedCursorAsMap)) - t.Run(fmt.Sprintf("invalid %s not base64", controllers.QueryKeyCursor), func(t *testing.T) { - rsp := internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyCursor: []string{"@!/"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: fmt.Sprintf("invalid '%s' query param", controllers.QueryKeyCursor), - }, err) + require.Equal(t, expectedCursorAsMap, cursorAsMap) + } else { + err := sharedapi.ErrorResponse{} + Decode(t, rec.Body, &err) + require.EqualValues(t, testCase.expectedErrorCode, err.ErrorCode) + } }) - }) + } } diff --git a/pkg/api/controllers/main_test.go b/pkg/api/controllers/main_test.go deleted file mode 100644 index 9488f9c6f..000000000 --- a/pkg/api/controllers/main_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package controllers_test - -import ( - "os" - "testing" - - "github.com/formancehq/stack/libs/go-libs/logging" - "github.com/formancehq/stack/libs/go-libs/pgtesting" -) - -func TestMain(t *testing.M) { - if err := pgtesting.CreatePostgresServer(); err != nil { - logging.Error(err) - os.Exit(1) - } - code := t.Run() - if err := pgtesting.DestroyPostgresServer(); err != nil { - logging.Error(err) - } - os.Exit(code) -} diff --git a/pkg/api/controllers/pagination_test.go b/pkg/api/controllers/pagination_test.go deleted file mode 100644 index c6d1aef20..000000000 --- a/pkg/api/controllers/pagination_test.go +++ /dev/null @@ -1,610 +0,0 @@ -package controllers_test - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "math/big" - "net/http" - "net/url" - "testing" - "time" - - "github.com/formancehq/ledger/pkg/api/controllers" - "github.com/formancehq/ledger/pkg/api/internal" - "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/storage" - sharedapi "github.com/formancehq/stack/libs/go-libs/api" - "github.com/go-chi/chi/v5" - "github.com/stretchr/testify/require" -) - -// This test makes sense if maxAdditionalTxs < pageSize -const ( - pageSize = 10 - maxTxsPages = 3 - maxAdditionalTxs = 2 -) - -func TestGetPagination(t *testing.T) { - for txsPages := 0; txsPages <= maxTxsPages; txsPages++ { - for additionalTxs := 0; additionalTxs <= maxAdditionalTxs; additionalTxs++ { - t.Run(fmt.Sprintf("%d-pages-%d-additional", txsPages, additionalTxs), func(t *testing.T) { - internal.RunTest(t, func(api chi.Router, storageDriver storage.Driver) { - testGetPagination(t, api, storageDriver, txsPages, additionalTxs) - }) - }) - } - } -} - -func testGetPagination(t *testing.T, api chi.Router, storageDriver storage.Driver, txsPages, additionalTxs int) func(ctx context.Context) error { - return func(ctx context.Context) error { - store, _, err := storageDriver.GetLedgerStore(ctx, internal.TestingLedger, true) - require.NoError(t, err) - - numTxs := txsPages*pageSize + additionalTxs - if numTxs > 0 { - for i := 0; i < numTxs; i++ { - require.NoError(t, store.InsertTransactions(context.Background(), core.ExpandTransactionFromEmptyPreCommitVolumes( - core.NewTransaction(). - WithPostings(core.NewPosting("world", fmt.Sprintf("accounts:%06d", i), "USD", big.NewInt(10))). - WithReference(fmt.Sprintf("ref:%06d", i)), - ))) - } - } - - rsp := internal.CountTransactions(api, url.Values{}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - require.Equal(t, fmt.Sprintf("%d", numTxs), rsp.Header().Get("Count")) - - numAcc := 0 - if numTxs > 0 { - numAcc = numTxs + 1 // + world account - } - rsp = internal.CountAccounts(api, url.Values{}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - require.Equal(t, fmt.Sprintf("%d", numAcc), rsp.Header().Get("Count")) - - accPages := numAcc / pageSize - additionalAccs := numAcc % pageSize - - t.Run("transactions", func(t *testing.T) { - var paginationToken string - cursor := &sharedapi.Cursor[core.ExpandedTransaction]{} - - // MOVING FORWARD - for i := 0; i < txsPages; i++ { - - values := url.Values{} - if paginationToken == "" { - values.Set(controllers.QueryKeyPageSize, fmt.Sprintf("%d", pageSize)) - } else { - values.Set(controllers.QueryKeyCursor, paginationToken) - } - - rsp = internal.GetTransactions(api, values) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor = internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - require.Len(t, cursor.Data, pageSize) - require.Equal(t, cursor.Next != "", cursor.HasMore) - - // First txid of the page - require.Equal(t, - uint64((txsPages-i)*pageSize+additionalTxs-1), cursor.Data[0].ID) - - // Last txid of the page - require.Equal(t, - uint64((txsPages-i-1)*pageSize+additionalTxs), cursor.Data[len(cursor.Data)-1].ID) - - paginationToken = cursor.Next - } - - if additionalTxs > 0 { - rsp = internal.GetTransactions(api, url.Values{ - controllers.QueryKeyCursor: []string{paginationToken}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode, rsp.Body.String()) - cursor = internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - require.Len(t, cursor.Data, additionalTxs) - require.Equal(t, cursor.Next != "", cursor.HasMore) - - // First txid of the last page - require.Equal(t, - uint64(additionalTxs-1), cursor.Data[0].ID) - - // Last txid of the last page - require.Equal(t, - uint64(0), cursor.Data[len(cursor.Data)-1].ID) - } - - require.Empty(t, cursor.Next) - - // MOVING BACKWARD - if txsPages > 0 { - back := 0 - for cursor.Previous != "" { - paginationToken = cursor.Previous - rsp = internal.GetTransactions(api, url.Values{ - controllers.QueryKeyCursor: []string{paginationToken}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor = internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - require.Len(t, cursor.Data, pageSize) - require.Equal(t, cursor.Next != "", cursor.HasMore) - back++ - } - if additionalTxs > 0 { - require.Equal(t, txsPages, back) - } else { - require.Equal(t, txsPages-1, back) - } - - // First txid of the first page - require.Equal(t, - uint64(txsPages*pageSize+additionalTxs-1), cursor.Data[0].ID) - - // Last txid of the first page - require.Equal(t, - uint64((txsPages-1)*pageSize+additionalTxs), cursor.Data[len(cursor.Data)-1].ID) - } - - require.Empty(t, cursor.Previous) - }) - - t.Run("accounts", func(t *testing.T) { - var paginationToken string - cursor := &sharedapi.Cursor[core.Account]{} - - // MOVING FORWARD - for i := 0; i < accPages; i++ { - - values := url.Values{} - if paginationToken == "" { - values.Set(controllers.QueryKeyPageSize, fmt.Sprintf("%d", pageSize)) - } else { - values.Set(controllers.QueryKeyCursor, paginationToken) - } - - rsp = internal.GetAccounts(api, values) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor = internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, pageSize) - require.Equal(t, cursor.Next != "", cursor.HasMore) - - // First account of the page - if i == 0 { - require.Equal(t, "world", - string(cursor.Data[0].Address)) - } else { - require.Equal(t, - fmt.Sprintf("accounts:%06d", (accPages-i)*pageSize+additionalAccs-1), - string(cursor.Data[0].Address)) - } - - // Last account of the page - require.Equal(t, - fmt.Sprintf("accounts:%06d", (accPages-i-1)*pageSize+additionalAccs), - string(cursor.Data[len(cursor.Data)-1].Address)) - - paginationToken = cursor.Next - } - - if additionalAccs > 0 { - rsp = internal.GetAccounts(api, url.Values{ - controllers.QueryKeyCursor: []string{paginationToken}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode, rsp.Body.String()) - cursor = internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, additionalAccs) - require.Equal(t, cursor.Next != "", cursor.HasMore) - - // First account of the last page - if accPages == 0 { - require.Equal(t, "world", - string(cursor.Data[0].Address)) - } else { - require.Equal(t, - fmt.Sprintf("accounts:%06d", additionalAccs-1), - string(cursor.Data[0].Address)) - } - - // Last account of the last page - require.Equal(t, - fmt.Sprintf("accounts:%06d", 0), - string(cursor.Data[len(cursor.Data)-1].Address)) - } - - require.Empty(t, cursor.Next) - - // MOVING BACKWARD - if accPages > 0 { - back := 0 - for cursor.Previous != "" { - paginationToken = cursor.Previous - rsp = internal.GetAccounts(api, url.Values{ - controllers.QueryKeyCursor: []string{paginationToken}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode, rsp.Body.String()) - cursor = internal.DecodeCursorResponse[core.Account](t, rsp.Body) - require.Len(t, cursor.Data, pageSize) - require.Equal(t, cursor.Next != "", cursor.HasMore) - back++ - } - if additionalAccs > 0 { - require.Equal(t, accPages, back) - } else { - require.Equal(t, accPages-1, back) - } - - // First account of the first page - require.Equal(t, "world", - string(cursor.Data[0].Address)) - - // Last account of the first page - require.Equal(t, - fmt.Sprintf("accounts:%06d", (txsPages-1)*pageSize+additionalTxs+1), - string(cursor.Data[len(cursor.Data)-1].Address)) - } - - require.Empty(t, cursor.Previous) - }) - - t.Run("balances", func(t *testing.T) { - var paginationToken string - cursor := &sharedapi.Cursor[core.AccountsBalances]{} - - // MOVING FORWARD - for i := 0; i < accPages; i++ { - - values := url.Values{} - if paginationToken == "" { - values.Set(controllers.QueryKeyPageSize, fmt.Sprintf("%d", pageSize)) - } else { - values.Set(controllers.QueryKeyCursor, paginationToken) - } - - rsp = internal.GetBalances(api, values) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor = internal.DecodeCursorResponse[core.AccountsBalances](t, rsp.Body) - require.Len(t, cursor.Data, pageSize) - require.Equal(t, cursor.Next != "", cursor.HasMore) - - // First account balances of the page - if i == 0 { - _, ok := cursor.Data[0]["world"] - require.True(t, ok) - } else { - _, ok := cursor.Data[0][fmt.Sprintf( - "accounts:%06d", (accPages-i)*pageSize+additionalAccs-1)] - require.True(t, ok) - } - - // Last account balances of the page - _, ok := cursor.Data[len(cursor.Data)-1][fmt.Sprintf( - "accounts:%06d", (accPages-i-1)*pageSize+additionalAccs)] - require.True(t, ok) - - paginationToken = cursor.Next - } - - if additionalAccs > 0 { - rsp = internal.GetBalances(api, url.Values{ - controllers.QueryKeyCursor: []string{paginationToken}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode, rsp.Body.String()) - cursor = internal.DecodeCursorResponse[core.AccountsBalances](t, rsp.Body) - require.Len(t, cursor.Data, additionalAccs) - require.Equal(t, cursor.Next != "", cursor.HasMore) - - // First account balances of the last page - if accPages == 0 { - _, ok := cursor.Data[0]["world"] - require.True(t, ok) - } else { - _, ok := cursor.Data[0][fmt.Sprintf( - "accounts:%06d", additionalAccs-1)] - require.True(t, ok) - } - - // Last account balances of the last page - _, ok := cursor.Data[len(cursor.Data)-1][fmt.Sprintf( - "accounts:%06d", 0)] - require.True(t, ok) - } - - // MOVING BACKWARD - if accPages > 0 { - back := 0 - for cursor.Previous != "" { - paginationToken = cursor.Previous - rsp = internal.GetBalances(api, url.Values{ - controllers.QueryKeyCursor: []string{paginationToken}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode, rsp.Body.String()) - cursor = internal.DecodeCursorResponse[core.AccountsBalances](t, rsp.Body) - require.Len(t, cursor.Data, pageSize) - require.Equal(t, cursor.Next != "", cursor.HasMore) - back++ - } - if additionalAccs > 0 { - require.Equal(t, accPages, back) - } else { - require.Equal(t, accPages-1, back) - } - - // First account balances of the first page - _, ok := cursor.Data[0]["world"] - require.True(t, ok) - - // Last account balances of the first page - _, ok = cursor.Data[len(cursor.Data)-1][fmt.Sprintf( - "accounts:%06d", (txsPages-1)*pageSize+additionalTxs+1)] - require.True(t, ok) - } - }) - - t.Run("logs", func(t *testing.T) { - var paginationToken string - cursor := &sharedapi.Cursor[core.Log]{} - - // MOVING FORWARD - for i := 0; i < txsPages; i++ { - - values := url.Values{} - if paginationToken == "" { - values.Set(controllers.QueryKeyPageSize, fmt.Sprintf("%d", pageSize)) - } else { - values.Set(controllers.QueryKeyCursor, paginationToken) - } - - rsp = internal.GetLedgerLogs(api, values) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor = internal.DecodeCursorResponse[core.Log](t, rsp.Body) - require.Len(t, cursor.Data, pageSize) - require.Equal(t, cursor.Next != "", cursor.HasMore) - - // First ID of the page - require.Equal(t, - uint64((txsPages-i)*pageSize+additionalTxs-1), cursor.Data[0].ID) - - // Last ID of the page - require.Equal(t, - uint64((txsPages-i-1)*pageSize+additionalTxs), cursor.Data[len(cursor.Data)-1].ID) - - paginationToken = cursor.Next - } - - if additionalTxs > 0 { - rsp = internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyCursor: []string{paginationToken}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode, rsp.Body.String()) - cursor = internal.DecodeCursorResponse[core.Log](t, rsp.Body) - require.Len(t, cursor.Data, additionalTxs) - require.Equal(t, cursor.Next != "", cursor.HasMore) - - // First ID of the last page - require.Equal(t, - uint64(additionalTxs-1), cursor.Data[0].ID) - - // Last ID of the last page - require.Equal(t, - uint64(0), cursor.Data[len(cursor.Data)-1].ID) - } - - require.Empty(t, cursor.Next) - - // MOVING BACKWARD - if txsPages > 0 { - back := 0 - for cursor.Previous != "" { - paginationToken = cursor.Previous - rsp = internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyCursor: []string{paginationToken}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor = internal.DecodeCursorResponse[core.Log](t, rsp.Body) - require.Len(t, cursor.Data, pageSize) - require.Equal(t, cursor.Next != "", cursor.HasMore) - back++ - } - if additionalTxs > 0 { - require.Equal(t, txsPages, back) - } else { - require.Equal(t, txsPages-1, back) - } - - // First ID of the first page - require.Equal(t, - uint64(txsPages*pageSize+additionalTxs-1), cursor.Data[0].ID) - - // Last ID of the first page - require.Equal(t, - uint64((txsPages-1)*pageSize+additionalTxs), cursor.Data[len(cursor.Data)-1].ID) - } - - require.Empty(t, cursor.Previous) - }) - - return nil - } -} - -func TestCursor(t *testing.T) { - internal.RunTest(t, func(api chi.Router, storageDriver storage.Driver) { - timestamp, err := core.ParseTime("2023-01-01T00:00:00Z") - require.NoError(t, err) - - store, _, err := storageDriver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) - - _, err = store.Initialize(context.Background()) - require.NoError(t, err) - - for i := 0; i < 30; i++ { - date := timestamp.Add(time.Duration(i) * time.Second) - tx := core.NewTransaction(). - WithPostings(core.NewPosting("world", fmt.Sprintf("accounts:%02d", i), "USD", big.NewInt(1))). - WithReference(fmt.Sprintf("ref:%02d", i)). - WithMetadata(core.Metadata{"ref": "abc"}). - WithTimestamp(date). - WithID(uint64(i)) - require.NoError(t, store.InsertTransactions(context.Background(), core.ExpandTransactionFromEmptyPreCommitVolumes(tx))) - log := core.NewTransactionLog(tx, nil).WithDate(date) - require.NoError(t, store.AppendLog(context.Background(), &log)) - require.NoError(t, store.EnsureAccountExists(context.Background(), fmt.Sprintf("accounts:%02d", i))) - require.NoError(t, store.UpdateAccountMetadata(context.Background(), fmt.Sprintf("accounts:%02d", i), core.Metadata{ - "foo": json.RawMessage(`"bar"`), - })) - require.NoError(t, store.UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ - fmt.Sprintf("accounts:%02d", i): { - "USD": core.NewEmptyVolumes().WithInput(big.NewInt(1)), - }, - })) - } - - t.Run("GetAccounts", func(t *testing.T) { - httpResponse := internal.GetAccounts(api, url.Values{ - "after": []string{"accounts:15"}, - "address": []string{"acc.*"}, - "metadata[foo]": []string{"bar"}, - "balance": []string{"1"}, - controllers.QueryKeyBalanceOperator: []string{"gte"}, - controllers.QueryKeyPageSize: []string{"3"}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) - - cursor := internal.DecodeCursorResponse[core.Account](t, httpResponse.Body) - res, err := base64.RawURLEncoding.DecodeString(cursor.Next) - require.NoError(t, err) - require.Equal(t, - `{"pageSize":3,"offset":3,"after":"accounts:15","address":"acc.*","metadata":{"foo":"bar"},"balance":"1","balanceOperator":"gte"}`, - string(res)) - - httpResponse = internal.GetAccounts(api, url.Values{ - controllers.QueryKeyCursor: []string{cursor.Next}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) - - cursor = internal.DecodeCursorResponse[core.Account](t, httpResponse.Body) - res, err = base64.RawURLEncoding.DecodeString(cursor.Previous) - require.NoError(t, err) - require.Equal(t, - `{"pageSize":3,"offset":0,"after":"accounts:15","address":"acc.*","metadata":{"foo":"bar"},"balance":"1","balanceOperator":"gte"}`, - string(res)) - res, err = base64.RawURLEncoding.DecodeString(cursor.Next) - require.NoError(t, err) - require.Equal(t, - `{"pageSize":3,"offset":6,"after":"accounts:15","address":"acc.*","metadata":{"foo":"bar"},"balance":"1","balanceOperator":"gte"}`, - string(res)) - }) - - t.Run("GetTransactions", func(t *testing.T) { - httpResponse := internal.GetTransactions(api, url.Values{ - "after": []string{"15"}, - "account": []string{"acc.*"}, - "source": []string{"world"}, - "destination": []string{"acc.*"}, - controllers.QueryKeyStartTime: []string{timestamp.Add(5 * time.Second).Format(time.RFC3339)}, - controllers.QueryKeyEndTime: []string{timestamp.Add(25 * time.Second).Format(time.RFC3339)}, - "metadata[ref]": []string{"abc"}, - controllers.QueryKeyPageSize: []string{"3"}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) - - cursor := internal.DecodeCursorResponse[core.Transaction](t, httpResponse.Body) - res, err := base64.RawURLEncoding.DecodeString(cursor.Next) - require.NoError(t, err) - require.Equal(t, - `{"after":12,"account":"acc.*","source":"world","destination":"acc.*","startTime":"2023-01-01T00:00:05Z","endTime":"2023-01-01T00:00:25Z","metadata":{"ref":"abc"},"pageSize":3}`, - string(res)) - - httpResponse = internal.GetTransactions(api, url.Values{ - controllers.QueryKeyCursor: []string{cursor.Next}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) - - cursor = internal.DecodeCursorResponse[core.Transaction](t, httpResponse.Body) - res, err = base64.RawURLEncoding.DecodeString(cursor.Previous) - require.NoError(t, err) - require.Equal(t, - `{"after":15,"account":"acc.*","source":"world","destination":"acc.*","startTime":"2023-01-01T00:00:05Z","endTime":"2023-01-01T00:00:25Z","metadata":{"ref":"abc"},"pageSize":3}`, - string(res)) - res, err = base64.RawURLEncoding.DecodeString(cursor.Next) - require.NoError(t, err) - require.Equal(t, - `{"after":9,"account":"acc.*","source":"world","destination":"acc.*","startTime":"2023-01-01T00:00:05Z","endTime":"2023-01-01T00:00:25Z","metadata":{"ref":"abc"},"pageSize":3}`, - string(res)) - }) - - t.Run("GetBalances", func(t *testing.T) { - httpResponse := internal.GetBalances(api, url.Values{ - "after": []string{"accounts:15"}, - "address": []string{"acc.*"}, - controllers.QueryKeyPageSize: []string{"3"}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) - - cursor := internal.DecodeCursorResponse[core.AccountsBalances](t, httpResponse.Body) - res, err := base64.RawURLEncoding.DecodeString(cursor.Next) - require.NoError(t, err) - require.Equal(t, - `{"pageSize":3,"offset":3,"after":"accounts:15","address":"acc.*"}`, - string(res)) - - httpResponse = internal.GetBalances(api, url.Values{ - controllers.QueryKeyCursor: []string{cursor.Next}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) - - cursor = internal.DecodeCursorResponse[core.AccountsBalances](t, httpResponse.Body) - res, err = base64.RawURLEncoding.DecodeString(cursor.Previous) - require.NoError(t, err) - require.Equal(t, - `{"pageSize":3,"offset":0,"after":"accounts:15","address":"acc.*"}`, - string(res)) - res, err = base64.RawURLEncoding.DecodeString(cursor.Next) - require.NoError(t, err) - require.Equal(t, - `{"pageSize":3,"offset":6,"after":"accounts:15","address":"acc.*"}`, - string(res)) - }) - - t.Run("GetLogs", func(t *testing.T) { - httpResponse := internal.GetLedgerLogs(api, url.Values{ - "after": []string{"30"}, - controllers.QueryKeyStartTime: []string{timestamp.Add(5 * time.Second).Format(time.RFC3339)}, - controllers.QueryKeyEndTime: []string{timestamp.Add(25 * time.Second).Format(time.RFC3339)}, - controllers.QueryKeyPageSize: []string{"2"}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) - - cursor := internal.DecodeCursorResponse[core.Log](t, httpResponse.Body) - res, err := base64.RawURLEncoding.DecodeString(cursor.Next) - require.NoError(t, err) - require.Equal(t, - `{"after":23,"pageSize":2,"startTime":"2023-01-01T00:00:05Z","endTime":"2023-01-01T00:00:25Z"}`, - string(res)) - - httpResponse = internal.GetLedgerLogs(api, url.Values{ - controllers.QueryKeyCursor: []string{cursor.Next}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) - - cursor = internal.DecodeCursorResponse[core.Log](t, httpResponse.Body) - res, err = base64.RawURLEncoding.DecodeString(cursor.Previous) - require.NoError(t, err) - require.Equal(t, - `{"after":25,"pageSize":2,"startTime":"2023-01-01T00:00:05Z","endTime":"2023-01-01T00:00:25Z"}`, - string(res)) - - res, err = base64.RawURLEncoding.DecodeString(cursor.Next) - require.NoError(t, err) - require.Equal(t, - `{"after":21,"pageSize":2,"startTime":"2023-01-01T00:00:05Z","endTime":"2023-01-01T00:00:25Z"}`, - string(res)) - }) - }) -} diff --git a/pkg/api/controllers/transaction_controller.go b/pkg/api/controllers/transaction_controller.go index b45f9ca88..9ddf924d0 100644 --- a/pkg/api/controllers/transaction_controller.go +++ b/pkg/api/controllers/transaction_controller.go @@ -46,7 +46,8 @@ func CountTransactions(w http.ResponseWriter, r *http.Request) { WithSourceFilter(r.URL.Query().Get("source")). WithDestinationFilter(r.URL.Query().Get("destination")). WithStartTimeFilter(startTimeParsed). - WithEndTimeFilter(endTimeParsed) + WithEndTimeFilter(endTimeParsed). + WithMetadataFilter(sharedapi.GetQueryMap(r.URL.Query(), "metadata")) count, err := l.CountTransactions(r.Context(), *txQuery) if err != nil { @@ -83,7 +84,7 @@ func GetTransactions(w http.ResponseWriter, r *http.Request) { return } - token := ledgerstore.TxsPaginationToken{} + token := ledgerstore.TransactionsPaginationToken{} if err = json.Unmarshal(res, &token); err != nil { apierrors.ResponseError(w, r, errorsutil.NewError(ledger.ErrValidation, errors.Errorf("invalid '%s' query param", QueryKeyCursor))) @@ -278,12 +279,6 @@ func PostTransactionMetadata(w http.ResponseWriter, r *http.Request) { return } - _, err = l.GetTransaction(r.Context(), txId) - if err != nil { - apierrors.ResponseError(w, r, err) - return - } - if err := l.SaveMeta(r.Context(), core.MetaTargetTypeTransaction, txId, m); err != nil { apierrors.ResponseError(w, r, err) return diff --git a/pkg/api/controllers/transaction_controller_test.go b/pkg/api/controllers/transaction_controller_test.go index 3335b0e6d..408300f00 100644 --- a/pkg/api/controllers/transaction_controller_test.go +++ b/pkg/api/controllers/transaction_controller_test.go @@ -1,414 +1,109 @@ package controllers_test import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" "math/big" "net/http" + "net/http/httptest" "net/url" "testing" - "time" "github.com/formancehq/ledger/pkg/api/apierrors" "github.com/formancehq/ledger/pkg/api/controllers" - "github.com/formancehq/ledger/pkg/api/internal" + "github.com/formancehq/ledger/pkg/api/routes" "github.com/formancehq/ledger/pkg/core" "github.com/formancehq/ledger/pkg/storage" ledgerstore "github.com/formancehq/ledger/pkg/storage/sqlstorage/ledger" sharedapi "github.com/formancehq/stack/libs/go-libs/api" - "github.com/go-chi/chi/v5" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" ) func TestPostTransactions(t *testing.T) { type testCase struct { - name string - initialTransactions []core.Transaction - payload controllers.PostTransactionRequest - expectedStatusCode int - expectedRes sharedapi.BaseResponse[[]core.ExpandedTransaction] - expectedErr sharedapi.ErrorResponse + name string + expectedDryRun bool + expectedRunScript core.RunScript + payload any + expectedStatusCode int + expectedErrorCode string + queryParams url.Values } - //var timestamp1 = core.Now().Add(1 * time.Minute) - var timestamp2 = core.Now().Add(2 * time.Minute) - var timestamp3 = core.Now().Add(3 * time.Minute) - testCases := []testCase{ { - name: "postings nominal", - payload: controllers.PostTransactionRequest{ - Postings: core.Postings{ - { - Source: "world", - Destination: "central_bank", - Amount: big.NewInt(1000), - Asset: "USB", - }, - }, - }, - expectedStatusCode: http.StatusOK, - expectedRes: sharedapi.BaseResponse[[]core.ExpandedTransaction]{ - Data: &[]core.ExpandedTransaction{{ - Transaction: core.Transaction{ - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "central_bank", - Amount: big.NewInt(1000), - Asset: "USB", - }, - }, - }, - }, - }}, - }, - }, - { - name: "postings asset with digit", - payload: controllers.PostTransactionRequest{ - Postings: core.Postings{ - { - Source: "world", - Destination: "central_bank", - Amount: big.NewInt(1000), - Asset: "US1234D", - }, - }, - }, - expectedStatusCode: http.StatusOK, - expectedRes: sharedapi.BaseResponse[[]core.ExpandedTransaction]{ - Data: &[]core.ExpandedTransaction{{ - Transaction: core.Transaction{ - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "central_bank", - Amount: big.NewInt(1000), - Asset: "US1234D", - }, - }, - }, - }, - }}, - }, - }, - { - name: "script nominal", + name: "using plain numscript", payload: controllers.PostTransactionRequest{ Script: core.Script{ - Plain: ` - vars { - account $acc - } - send [COIN 100] ( - source = @world - destination = @centralbank - ) - send [COIN 100] ( - source = @centralbank - destination = $acc + Plain: `send ( + source = @world + destination = @bank )`, - Vars: map[string]json.RawMessage{ - "acc": json.RawMessage(`"users:001"`), - }, }, }, - expectedStatusCode: http.StatusOK, - expectedRes: sharedapi.BaseResponse[[]core.ExpandedTransaction]{ - Data: &[]core.ExpandedTransaction{{ - Transaction: core.Transaction{ - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "centralbank", - Amount: big.NewInt(100), - Asset: "COIN", - }, - { - Source: "centralbank", - Destination: "users:001", - Amount: big.NewInt(100), - Asset: "COIN", - }, - }, - }, - }, - }}, - }, - }, - { - name: "script with set_account_meta", - payload: controllers.PostTransactionRequest{ + expectedRunScript: core.RunScript{ Script: core.Script{ - Plain: ` - send [TOK 1000] ( - source = @world - destination = @bar - ) - set_account_meta(@bar, "foo", "bar") - `, - }, - }, - expectedStatusCode: http.StatusOK, - expectedRes: sharedapi.BaseResponse[[]core.ExpandedTransaction]{ - Data: &[]core.ExpandedTransaction{{ - Transaction: core.Transaction{ - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "bar", - Amount: big.NewInt(1000), - Asset: "TOK", - }, - }, - }, - }, - }}, - }, - }, - { - name: "no postings or script", - payload: controllers.PostTransactionRequest{}, - expectedStatusCode: http.StatusBadRequest, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid payload: should contain either postings or script", - }, - }, - { - name: "postings negative amount", - payload: controllers.PostTransactionRequest{ - Postings: core.Postings{ - { - Source: "world", - Destination: "central_bank", - Amount: big.NewInt(-1000), - Asset: "USB", - }, - }, - }, - expectedStatusCode: http.StatusBadRequest, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid posting 0: negative amount", - }, - }, - { - name: "postings wrong asset with symbol", - payload: controllers.PostTransactionRequest{ - Postings: core.Postings{ - { - Source: "world", - Destination: "central_bank", - Amount: big.NewInt(1000), - Asset: "@TOK", - }, - }, - }, - expectedStatusCode: http.StatusBadRequest, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid posting 0: invalid asset", - }, - }, - { - name: "postings wrong asset with digit as first char", - payload: controllers.PostTransactionRequest{ - Postings: core.Postings{ - { - Source: "world", - Destination: "central_bank", - Amount: big.NewInt(1000), - Asset: "1TOK", - }, - }, - }, - expectedStatusCode: http.StatusBadRequest, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid posting 0: invalid asset", - }, - }, - { - name: "postings bad address", - payload: controllers.PostTransactionRequest{ - Postings: core.Postings{ - { - Source: "world", - Destination: "#fake", - Amount: big.NewInt(1000), - Asset: "TOK", - }, - }, - }, - expectedStatusCode: http.StatusBadRequest, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid posting 0: invalid destination address", - }, - }, - { - name: "postings insufficient funds", - payload: controllers.PostTransactionRequest{ - Postings: core.Postings{ - { - Source: "foo", - Destination: "bar", - Amount: big.NewInt(1000), - Asset: "TOK", - }, - }, - }, - expectedStatusCode: http.StatusBadRequest, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrInsufficientFund, - ErrorMessage: "[INSUFFICIENT_FUND] account had insufficient funds", - }, - }, - { - name: "postings reference conflict", - initialTransactions: []core.Transaction{{ - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "bar", - Amount: big.NewInt(1000), - Asset: "TOK", - }, - }, - Reference: "ref", - }, - }}, - payload: controllers.PostTransactionRequest{ - Postings: core.Postings{ - { - Source: "world", - Destination: "bar", - Amount: big.NewInt(1000), - Asset: "TOK", - }, + Plain: `send ( + source = @world + destination = @bank + )`, }, - Reference: "ref", - }, - expectedStatusCode: http.StatusConflict, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrConflict, - ErrorMessage: "conflict error on reference", }, }, { - name: "script failure with insufficient funds", + name: "using plain numscript and dry run", payload: controllers.PostTransactionRequest{ Script: core.Script{ - Plain: ` - send [COIN 100] ( - source = @centralbank - destination = @users:001 + Plain: `send ( + source = @world + destination = @bank )`, }, }, - expectedStatusCode: http.StatusBadRequest, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrInsufficientFund, - ErrorMessage: "[INSUFFICIENT_FUND] account had insufficient funds", - Details: apierrors.EncodeLink("account had insufficient funds"), - }, - }, - { - name: "script failure with metadata override", - payload: controllers.PostTransactionRequest{ + expectedRunScript: core.RunScript{ Script: core.Script{ - Plain: ` - set_tx_meta("priority", "low") - - send [USD/2 99] ( - source=@world - destination=@user:001 + Plain: `send ( + source = @world + destination = @bank )`, }, - Metadata: core.Metadata{ - "priority": json.RawMessage(`"high"`), - }, }, - expectedStatusCode: http.StatusBadRequest, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrScriptMetadataOverride, - ErrorMessage: "[METADATA_OVERRIDE] cannot override metadata from script", - Details: apierrors.EncodeLink("cannot override metadata from script"), + expectedDryRun: true, + queryParams: url.Values{ + "preview": []string{"true"}, }, }, { - name: "script failure with no postings", + name: "using JSON postings", payload: controllers.PostTransactionRequest{ - Script: core.Script{ - Plain: ` - set_account_meta(@bar, "foo", "bar") - `, + Postings: []core.Posting{ + core.NewPosting("world", "bank", "USD", big.NewInt(100)), }, }, - expectedStatusCode: http.StatusBadRequest, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "transaction has no postings", - }, + expectedRunScript: core.TxToScriptData(core.NewTransactionData().WithPostings( + core.NewPosting("world", "bank", "USD", big.NewInt(100)), + )), }, - //{ - // name: "script failure with invalid account variable", - // payload: controllers.PostTransactionRequest{ - // Script: core.Script{ - // Plain: ` - // vars { - // account $acc - // } - // send [USD/2 99] ( - // source = @world - // destination = $acc - // ) - // `, - // Vars: map[string]json.RawMessage{ - // "acc": json.RawMessage(`"invalid-acc"`), - // }, - // }, - // }, - // expectedStatusCode: http.StatusBadRequest, - // expectedErr: sharedapi.ErrorResponse{ - // ErrorCode: apierrors.ErrScriptCompilationFailed, - // ErrorMessage: "[COMPILATION_FAILED] value invalid-acc: accounts should respect pattern ^[a-zA-Z_]+[a-zA-Z0-9_:]*$", - // Details: apierrors.EncodeLink("value invalid-acc: accounts should respect pattern ^[a-zA-Z_]+[a-zA-Z0-9_:]*$"), - // }, - //}, { - name: "script failure with invalid monetary variable", + name: "using JSON postings and dry run", + queryParams: url.Values{ + // TODO(gfyrag): Rename to dry run + "preview": []string{"true"}, + }, payload: controllers.PostTransactionRequest{ - Script: core.Script{ - Plain: ` - vars { - monetary $mon - } - send $mon ( - source = @world - destination = @alice - ) - `, - Vars: map[string]json.RawMessage{ - "mon": json.RawMessage(`{"asset": "COIN","amount":-1}`), - }, + Postings: []core.Posting{ + core.NewPosting("world", "bank", "USD", big.NewInt(100)), }, }, + expectedDryRun: true, + expectedRunScript: core.TxToScriptData(core.NewTransactionData().WithPostings( + core.NewPosting("world", "bank", "USD", big.NewInt(100)), + )), + }, + { + name: "no postings or script", + payload: controllers.PostTransactionRequest{}, expectedStatusCode: http.StatusBadRequest, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrScriptCompilationFailed, - ErrorMessage: "[COMPILATION_FAILED] could not set variables: invalid JSON value for variable $mon of type monetary: value [COIN -1]: negative amount", - Details: apierrors.EncodeLink("invalid JSON value for variable $mon of type monetary: value [COIN -1]: negative amount"), - }, + expectedErrorCode: apierrors.ErrValidation, }, { name: "postings and script", @@ -430,1013 +125,514 @@ func TestPostTransactions(t *testing.T) { }, }, expectedStatusCode: http.StatusBadRequest, - expectedErr: sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid payload: should contain either postings or script", - }, - }, - { - name: "postings with specified timestamp", - payload: controllers.PostTransactionRequest{ - Postings: core.Postings{ - { - Source: "world", - Destination: "bar", - Amount: big.NewInt(1000), - Asset: "TOK", - }, - }, - Timestamp: timestamp2, - }, - expectedStatusCode: http.StatusOK, - expectedRes: sharedapi.BaseResponse[[]core.ExpandedTransaction]{ - Data: &[]core.ExpandedTransaction{{ - Transaction: core.Transaction{ - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "bar", - Amount: big.NewInt(1000), - Asset: "TOK", - }, - }, - }, - }, - }}, - }, + expectedErrorCode: apierrors.ErrValidation, }, { - name: "script with specified timestamp", - payload: controllers.PostTransactionRequest{ - Script: core.Script{ - Plain: ` - send [TOK 1000] ( - source = @world - destination = @bar - ) - `, - }, - Timestamp: timestamp3, - }, - expectedStatusCode: http.StatusOK, - expectedRes: sharedapi.BaseResponse[[]core.ExpandedTransaction]{ - Data: &[]core.ExpandedTransaction{{ - Transaction: core.Transaction{ - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "bar", - Amount: big.NewInt(1000), - Asset: "TOK", - }, - }, - }, - }, - }}, - }, - }, - //{ - // name: "postings with specified timestamp prior to last tx", - // initialTransactions: []core.Transaction{{ - // TransactionData: core.TransactionData{ - // Postings: core.Postings{ - // { - // Source: "world", - // Destination: "bar", - // Amount: big.NewInt(1000), - // Asset: "TOK", - // }, - // }, - // Timestamp: timestamp2, - // }, - // }}, - // payload: controllers.PostTransactionRequest{ - // Postings: core.Postings{ - // { - // Source: "world", - // Destination: "bar", - // Amount: big.NewInt(1000), - // Asset: "TOK", - // }, - // }, - // Timestamp: timestamp1, - // }, - // expectedStatusCode: http.StatusBadRequest, - // expectedErr: sharedapi.ErrorResponse{ - // ErrorCode: apierrors.ErrValidation, - // ErrorMessage: "cannot pass a timestamp prior to the last transaction:", - // }, - //}, - //{ - // name: "script with specified timestamp prior to last tx", - // initialTransactions: []core.Transaction{ - // core.NewTransaction(). - // WithPostings(core.NewPosting("world", "bob", "COIN", big.NewInt(100))). - // WithTimestamp(timestamp2), - // }, - // payload: controllers.PostTransactionRequest{ - // Script: core.Script{ - // Plain: ` - // send [COIN 100] ( - // source = @world - // destination = @bob - // )`, - // }, - // Timestamp: timestamp1, - // }, - // expectedStatusCode: http.StatusBadRequest, - // expectedErr: sharedapi.ErrorResponse{ - // ErrorCode: apierrors.ErrValidation, - // ErrorMessage: "cannot pass a timestamp prior to the last transaction:", - // }, - //}, - { - name: "short asset", - payload: controllers.PostTransactionRequest{ - Postings: core.Postings{ - { - Source: "world", - Destination: "bank", - Amount: big.NewInt(1000), - Asset: "F/9", - }, - }, - Timestamp: timestamp3, - }, - expectedStatusCode: http.StatusOK, - expectedRes: sharedapi.BaseResponse[[]core.ExpandedTransaction]{ - Data: &[]core.ExpandedTransaction{{ - Transaction: core.Transaction{ - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "bank", - Amount: big.NewInt(1000), - Asset: "F/9", - }, - }, - Timestamp: timestamp3, - }, - }, - }}, - }, + name: "using invalid body", + payload: "not a valid payload", + expectedStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, }, } - for _, tc := range testCases { - tc := tc + for _, testCase := range testCases { + tc := testCase t.Run(tc.name, func(t *testing.T) { - internal.RunTest(t, func(api chi.Router, storageDriver storage.Driver) { - - store, _, err := storageDriver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) + if testCase.expectedStatusCode == 0 { + testCase.expectedStatusCode = http.StatusOK + } - _, err = store.Initialize(context.Background()) - require.NoError(t, err) + expectedTx := core.ExpandTransaction( + core.NewTransaction().WithPostings( + core.NewPosting("world", "bank", "USD", big.NewInt(100)), + ), + nil, + ) + + backend, mockLedger := newTestingBackend(t) + if testCase.expectedStatusCode < 300 && testCase.expectedStatusCode >= 200 { + mockLedger.EXPECT(). + CreateTransaction(gomock.Any(), testCase.expectedDryRun, testCase.expectedRunScript). + Return(&expectedTx, nil) + } - for _, transaction := range tc.initialTransactions { - log := core.NewTransactionLog(transaction, nil). - WithReference(transaction.Reference) - require.NoError(t, store.AppendLog(context.Background(), &log)) - } + router := routes.NewRouter(backend, nil, nil) - rsp := internal.PostTransaction(t, api, tc.payload, false) - require.Equal(t, tc.expectedStatusCode, rsp.Result().StatusCode, rsp.Body.String()) + req := httptest.NewRequest(http.MethodPost, "/xxx/transactions", Buffer(t, testCase.payload)) + rec := httptest.NewRecorder() + req.URL.RawQuery = testCase.queryParams.Encode() - if tc.expectedStatusCode != http.StatusOK { - actualErr := sharedapi.ErrorResponse{} - if internal.Decode(t, rsp.Body, &actualErr) { - require.Equal(t, tc.expectedErr.ErrorCode, actualErr.ErrorCode, actualErr.ErrorMessage) - if tc.expectedErr.Details != "" { - require.Equal(t, tc.expectedErr.Details, actualErr.Details) - } - } - } else { - txs, ok := internal.DecodeSingleResponse[core.ExpandedTransaction](t, rsp.Body) - require.True(t, ok) - require.Equal(t, (*tc.expectedRes.Data)[0].Postings, txs.Postings) - require.Equal(t, len((*tc.expectedRes.Data)[0].Metadata), len(txs.Metadata)) + router.ServeHTTP(rec, req) - if !tc.payload.Timestamp.IsZero() { - require.Equal(t, tc.payload.Timestamp, txs.Timestamp) - } - } - }) + require.Equal(t, testCase.expectedStatusCode, rec.Code) + if testCase.expectedStatusCode < 300 && testCase.expectedStatusCode >= 200 { + tx, ok := DecodeSingleResponse[core.ExpandedTransaction](t, rec.Body) + require.True(t, ok) + require.Equal(t, expectedTx, tx) + } else { + err := sharedapi.ErrorResponse{} + Decode(t, rec.Body, &err) + require.EqualValues(t, testCase.expectedErrorCode, err.ErrorCode) + } }) } } -func TestPostTransactionsPreview(t *testing.T) { - script := ` - send [COIN 100] ( - source = @world - destination = @centralbank - )` - - internal.RunTest(t, func(api chi.Router, driver storage.Driver) { - store := internal.GetLedgerStore(t, driver, context.Background()) - t.Run("postings true", func(t *testing.T) { - rsp := internal.PostTransaction(t, api, controllers.PostTransactionRequest{ - Postings: core.Postings{ - { - Source: "world", - Destination: "central_bank", - Amount: big.NewInt(1000), - Asset: "USD", - }, - }, - }, true) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - _, ok := internal.DecodeSingleResponse[core.ExpandedTransaction](t, rsp.Body) - require.True(t, ok) - - cursor, err := store.GetTransactions(context.Background(), *storage.NewTransactionsQuery()) - require.NoError(t, err) - require.Len(t, cursor.Data, 0) - }) - - t.Run("script true", func(t *testing.T) { - rsp := internal.PostTransaction(t, api, controllers.PostTransactionRequest{ - Script: core.Script{ - Plain: script, - }, - }, true) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - _, ok := internal.DecodeSingleResponse[core.ExpandedTransaction](t, rsp.Body) - require.True(t, ok) - - cursor, err := store.GetTransactions(context.Background(), *storage.NewTransactionsQuery()) - require.NoError(t, err) - require.Len(t, cursor.Data, 0) - }) - }) -} - -func TestPostTransactionInvalidBody(t *testing.T) { - internal.RunTest(t, func(api chi.Router, storageDriver storage.Driver) { - t.Run("no JSON", func(t *testing.T) { - rsp := internal.NewPostOnLedger(t, api, "/transactions", "invalid") - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid transaction format", - }, err) - }) - - t.Run("JSON without postings", func(t *testing.T) { - rsp := internal.NewPostOnLedger(t, api, "/transactions", core.Account{Address: "addr"}) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid payload: should contain either postings or script", - }, err) - }) - }) -} - func TestPostTransactionMetadata(t *testing.T) { - internal.RunTest(t, func(api chi.Router, storageDriver storage.Driver) { - store, _, err := storageDriver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) - - _, err = store.Initialize(context.Background()) - require.NoError(t, err) - - require.NoError(t, store.InsertTransactions(context.Background(), core.ExpandTransactionFromEmptyPreCommitVolumes( - core.NewTransaction().WithPostings( - core.NewPosting("world", "central_bank", "USD", big.NewInt(1000)), - ), - ))) + t.Parallel() - t.Run("valid", func(t *testing.T) { - rsp := internal.PostTransactionMetadata(t, api, 0, core.Metadata{ - "foo": json.RawMessage(`"bar"`), - }) - require.Equal(t, http.StatusNoContent, rsp.Result().StatusCode) - }) + type testCase struct { + name string + queryParams url.Values + expectStatusCode int + expectedErrorCode string + body any + } - t.Run("different metadata on same key should replace it", func(t *testing.T) { - rsp := internal.PostTransactionMetadata(t, api, 0, core.Metadata{ - "foo": "baz", - }) - require.Equal(t, http.StatusNoContent, rsp.Result().StatusCode) - }) + testCases := []testCase{ + { + name: "nominal", + body: core.Metadata{ + "foo": "bar", + }, + }, + { + name: "invalid body", + body: "invalid - not an object", + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { - t.Run("transaction not found", func(t *testing.T) { - rsp := internal.PostTransactionMetadata(t, api, 42, core.Metadata{ - "foo": "baz", - }) - require.Equal(t, http.StatusNotFound, rsp.Result().StatusCode) + if testCase.expectStatusCode == 0 { + testCase.expectStatusCode = http.StatusNoContent + } - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrNotFound, - ErrorMessage: "not found", - }, err) - }) + backend, mock := newTestingBackend(t) + if testCase.expectStatusCode == http.StatusNoContent { + mock.EXPECT(). + SaveMeta(gomock.Any(), core.MetaTargetTypeTransaction, uint64(0), testCase.body). + Return(nil) + } - t.Run("no JSON", func(t *testing.T) { - rsp := internal.NewPostOnLedger(t, api, "/transactions/0/metadata", "invalid") - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) + router := routes.NewRouter(backend, nil, nil) - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid metadata format", - }, err) - }) + req := httptest.NewRequest(http.MethodPost, "/xxx/transactions/0/metadata", Buffer(t, testCase.body)) + rec := httptest.NewRecorder() + req.URL.RawQuery = testCase.queryParams.Encode() - t.Run("invalid txid", func(t *testing.T) { - rsp := internal.NewPostOnLedger(t, api, "/transactions/invalid/metadata", core.Metadata{ - "foo": json.RawMessage(`"bar"`), - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) + router.ServeHTTP(rec, req) - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid transaction ID", - }, err) + require.Equal(t, testCase.expectStatusCode, rec.Code) + if testCase.expectStatusCode >= 300 || testCase.expectStatusCode < 200 { + err := sharedapi.ErrorResponse{} + Decode(t, rec.Body, &err) + require.EqualValues(t, testCase.expectedErrorCode, err.ErrorCode) + } }) - }) + } } func TestGetTransaction(t *testing.T) { - internal.RunTest(t, func(api chi.Router, storageDriver storage.Driver) { - - store, _, err := storageDriver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) + t.Parallel() - _, err = store.Initialize(context.Background()) - require.NoError(t, err) + tx := core.ExpandTransaction( + core.NewTransaction().WithPostings( + core.NewPosting("world", "bank", "USD", big.NewInt(100)), + ), + nil, + ) - require.NoError(t, store.InsertTransactions(context.Background(), core.ExpandTransactionFromEmptyPreCommitVolumes( - core.NewTransaction(). - WithPostings(core.NewPosting("world", "central_bank", "USD", big.NewInt(1000))). - WithReference("ref"). - WithTimestamp(core.Now()), - ))) + backend, mock := newTestingBackend(t) + mock.EXPECT(). + GetTransaction(gomock.Any(), uint64(0)). + Return(&tx, nil) - t.Run("valid txid", func(t *testing.T) { - rsp := internal.GetTransaction(api, 0) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) + router := routes.NewRouter(backend, nil, nil) - ret, _ := internal.DecodeSingleResponse[core.ExpandedTransaction](t, rsp.Body) - require.EqualValues(t, core.Postings{ - { - Source: "world", - Destination: "central_bank", - Amount: big.NewInt(1000), - Asset: "USD", - }, - }, ret.Postings) - require.EqualValues(t, 0, ret.ID) - require.EqualValues(t, core.Metadata{}, ret.Metadata) - require.EqualValues(t, "ref", ret.Reference) - require.NotEmpty(t, ret.Timestamp) - require.EqualValues(t, core.AccountsAssetsVolumes{ - "world": core.AssetsVolumes{ - "USD": { - Input: big.NewInt(0), - Output: big.NewInt(0), - }, - }, - "central_bank": core.AssetsVolumes{ - "USD": { - Input: big.NewInt(0), - Output: big.NewInt(0), - }, - }, - }, ret.PreCommitVolumes) - require.EqualValues(t, core.AccountsAssetsVolumes{ - "world": core.AssetsVolumes{ - "USD": { - Input: big.NewInt(0), - Output: big.NewInt(1000), - }, - }, - "central_bank": core.AssetsVolumes{ - "USD": { - Input: big.NewInt(1000), - Output: big.NewInt(0), - }, - }, - }, ret.PostCommitVolumes) - }) + req := httptest.NewRequest(http.MethodGet, "/xxx/transactions/0", nil) + rec := httptest.NewRecorder() - t.Run("unknown txid", func(t *testing.T) { - rsp := internal.GetTransaction(api, 42) - require.Equal(t, http.StatusNotFound, rsp.Result().StatusCode) + router.ServeHTTP(rec, req) - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrNotFound, - ErrorMessage: "not found", - }, err) - }) + require.Equal(t, http.StatusOK, rec.Code) + response, _ := DecodeSingleResponse[core.ExpandedTransaction](t, rec.Body) + require.Equal(t, tx, response) +} - t.Run("invalid txid", func(t *testing.T) { - rsp := internal.NewGetOnLedger(api, "/transactions/invalid") - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) +func TestGetTransactions(t *testing.T) { + t.Parallel() - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid transaction ID", - }, err) - }) - }) -} + type testCase struct { + name string + queryParams url.Values + expectQuery storage.TransactionsQuery + expectStatusCode int + expectedErrorCode string + } + now := core.Now() -func TestTransactions(t *testing.T) { - internal.RunTest(t, func(api chi.Router, driver storage.Driver) { - now := core.Now() - tx1 := core.ExpandedTransaction{ - Transaction: core.Transaction{ - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "central_bank1", - Amount: big.NewInt(1000), - Asset: "USD", - }, - }, - Reference: "ref:001", - Timestamp: now.Add(-3 * time.Hour), - }, + testCases := []testCase{ + { + name: "nominal", + expectQuery: *storage.NewTransactionsQuery(), + }, + { + name: "using metadata", + queryParams: url.Values{ + "metadata[roles]": []string{"admin"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithMetadataFilter(map[string]string{ + "roles": "admin", + }), + }, + { + name: "using nested metadata", + queryParams: url.Values{ + "metadata[a.nested.key]": []string{"hello"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithMetadataFilter(map[string]string{ + "a.nested.key": "hello", + }), + }, + { + name: "using after", + queryParams: url.Values{ + "after": []string{"10"}, }, - } - tx2 := core.ExpandedTransaction{ - Transaction: core.Transaction{ - ID: 1, - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: "central_bank2", - Amount: big.NewInt(1000), - Asset: "USD", - }, - }, - Metadata: core.Metadata{ - "foo": "bar", - }, - Reference: "ref:002", - Timestamp: now.Add(-2 * time.Hour), - }, + expectQuery: *storage.NewTransactionsQuery(). + WithAfterTxID(10), + }, + { + name: "using startTime", + queryParams: url.Values{ + "startTime": []string{now.Format(core.DateFormat)}, }, - } - tx3 := core.ExpandedTransaction{ - Transaction: core.Transaction{ - ID: 2, - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "central_bank1", - Destination: "alice", - Amount: big.NewInt(10), - Asset: "USD", - }, - }, - Reference: "ref:003", - Metadata: core.Metadata{ - "priority": "high", - }, - Timestamp: now.Add(-1 * time.Hour), - }, + expectQuery: *storage.NewTransactionsQuery(). + WithStartTimeFilter(now), + }, + { + name: "using invalid startTime", + queryParams: url.Values{ + "startTime": []string{"xxx"}, }, - } - store := internal.GetLedgerStore(t, driver, context.Background()) - _, err := store.Initialize(context.Background()) - require.NoError(t, err) - - err = store.InsertTransactions(context.Background(), tx1, tx2, tx3) - require.NoError(t, err) - - var tx1Timestamp, tx2Timestamp core.Time - t.Run("Get", func(t *testing.T) { - t.Run("all", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - // all transactions - require.Len(t, cursor.Data, 3) - require.Equal(t, cursor.Data[0].ID, uint64(2)) - require.Equal(t, cursor.Data[1].ID, uint64(1)) - require.Equal(t, cursor.Data[2].ID, uint64(0)) - - tx1Timestamp = cursor.Data[1].Timestamp - tx2Timestamp = cursor.Data[0].Timestamp - }) - - t.Run("metadata", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - "metadata[priority]": []string{"high"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - - require.Len(t, cursor.Data, 1) - require.Equal(t, cursor.Data[0].ID, tx3.ID) - }) - - t.Run("after", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - "after": []string{"1"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - // 1 transaction: txid 0 - require.Len(t, cursor.Data, 1) - require.Equal(t, cursor.Data[0].ID, uint64(0)) - }) - - t.Run("invalid after", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - "after": []string{"invalid"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid 'after' query param", - }, err) - }) - - t.Run("reference", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - "reference": []string{"ref:001"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - // 1 transaction: txid 0 - require.Len(t, cursor.Data, 1) - require.Equal(t, cursor.Data[0].ID, uint64(0)) - }) - - t.Run("destination", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - "destination": []string{"central_bank1"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - // 1 transaction: txid 0 - require.Len(t, cursor.Data, 1) - require.Equal(t, cursor.Data[0].ID, uint64(0)) - }) - - t.Run("source", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - "source": []string{"world"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - // 2 transactions: txid 0 and txid 1 - require.Len(t, cursor.Data, 2) - require.Equal(t, cursor.Data[0].ID, uint64(1)) - require.Equal(t, cursor.Data[1].ID, uint64(0)) - }) - - t.Run("account", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - "account": []string{"world"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - // 2 transactions: txid 0 and txid 1 - require.Len(t, cursor.Data, 2) - require.Equal(t, cursor.Data[0].ID, uint64(1)) - require.Equal(t, cursor.Data[1].ID, uint64(0)) - }) - - t.Run("account no result", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - "account": []string{"central"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - require.Len(t, cursor.Data, 0) - }) - - t.Run("account regex expr", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - "account": []string{"central.*"}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - require.Len(t, cursor.Data, 3) - }) - - t.Run("time range", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyStartTime: []string{tx1Timestamp.Format(time.RFC3339)}, - controllers.QueryKeyEndTime: []string{tx2Timestamp.Format(time.RFC3339)}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - // 1 transaction: txid 1 - require.Len(t, cursor.Data, 1) - }) - - t.Run("only start time", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyStartTime: []string{core.Now().Add(time.Second).Format(time.RFC3339)}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - // no transaction - require.Len(t, cursor.Data, 0) - }) - - t.Run("only end time", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyEndTime: []string{core.Now().Add(time.Second).Format(time.RFC3339)}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, rsp.Body) - // all transactions - require.Len(t, cursor.Data, 3) - }) - - t.Run("invalid start time", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyStartTime: []string{"invalid time"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: controllers.ErrInvalidStartTime.Error(), - }, err) - }) - - t.Run("invalid end time", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyEndTime: []string{"invalid time"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: controllers.ErrInvalidEndTime.Error(), - }, err) - }) - - t.Run("invalid page size", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyPageSize: []string{"invalid page size"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: controllers.ErrInvalidPageSize.Error(), - }, err) - }) - - to := ledgerstore.TxsPaginationToken{} - raw, err := json.Marshal(to) - require.NoError(t, err) - - t.Run(fmt.Sprintf("valid empty %s", controllers.QueryKeyCursor), func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyCursor: []string{base64.RawURLEncoding.EncodeToString(raw)}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode, rsp.Body.String()) - }) - - t.Run(fmt.Sprintf("valid empty %s with any other param is forbidden", controllers.QueryKeyCursor), func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyCursor: []string{base64.RawURLEncoding.EncodeToString(raw)}, - "after": []string{"1"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: fmt.Sprintf("no other query params can be set with '%s'", controllers.QueryKeyCursor), - }, err) - }) - - t.Run(fmt.Sprintf("invalid %s", controllers.QueryKeyCursor), func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyCursor: []string{"invalid"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: fmt.Sprintf("invalid '%s' query param", controllers.QueryKeyCursor), - }, err) - }) - - t.Run(fmt.Sprintf("invalid %s not base64", controllers.QueryKeyCursor), func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyCursor: []string{"@!/"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "using endTime", + queryParams: url.Values{ + "endTime": []string{now.Format(core.DateFormat)}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithEndTimeFilter(now), + }, + { + name: "using invalid endTime", + queryParams: url.Values{ + "endTime": []string{"xxx"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "using account", + queryParams: url.Values{ + "account": []string{"xxx"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithAccountFilter("xxx"), + }, + { + name: "using reference", + queryParams: url.Values{ + "reference": []string{"xxx"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithReferenceFilter("xxx"), + }, + { + name: "using destination", + queryParams: url.Values{ + "destination": []string{"xxx"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithDestinationFilter("xxx"), + }, + { + name: "using source", + queryParams: url.Values{ + "source": []string{"xxx"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithSourceFilter("xxx"), + }, + { + name: "using empty cursor", + queryParams: url.Values{ + "cursor": []string{ledgerstore.TransactionsPaginationToken{}.Encode()}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithMetadataFilter(nil), + }, + { + name: "using cursor with other param", + queryParams: url.Values{ + "cursor": []string{ledgerstore.TransactionsPaginationToken{}.Encode()}, + "after": []string{"foo"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "using invalid cursor", + queryParams: url.Values{ + "cursor": []string{"XXX"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "invalid page size", + queryParams: url.Values{ + "pageSize": []string{"nan"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "invalid after", + queryParams: url.Values{ + "after": []string{"nan"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "page size over maximum", + queryParams: url.Values{ + "pageSize": []string{"1000000"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithPageSize(controllers.MaxPageSize). + WithMetadataFilter(map[string]string{}), + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: fmt.Sprintf("invalid '%s' query param", controllers.QueryKeyCursor), - }, err) - }) - }) + if testCase.expectStatusCode == 0 { + testCase.expectStatusCode = http.StatusOK + } - t.Run("Count", func(t *testing.T) { - t.Run("all", func(t *testing.T) { - rsp := internal.CountTransactions(api, url.Values{}) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - require.Equal(t, "3", rsp.Header().Get("Count")) - }) + expectedCursor := sharedapi.Cursor[core.ExpandedTransaction]{ + Data: []core.ExpandedTransaction{ + core.ExpandTransaction( + core.NewTransaction().WithPostings( + core.NewPosting("world", "bank", "USD", big.NewInt(100)), + ), + nil, + ), + }, + } - t.Run("time range", func(t *testing.T) { - rsp := internal.CountTransactions(api, url.Values{ - controllers.QueryKeyStartTime: []string{tx1Timestamp.Format(time.RFC3339)}, - controllers.QueryKeyEndTime: []string{tx2Timestamp.Format(time.RFC3339)}, - }) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - require.Equal(t, "1", rsp.Header().Get("Count")) - }) + backend, mockLedger := newTestingBackend(t) + if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { + mockLedger.EXPECT(). + GetTransactions(gomock.Any(), testCase.expectQuery). + Return(expectedCursor, nil) + } - t.Run("invalid start time", func(t *testing.T) { - rsp := internal.CountTransactions(api, url.Values{ - controllers.QueryKeyStartTime: []string{"invalid"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) + router := routes.NewRouter(backend, nil, nil) - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: controllers.ErrInvalidStartTime.Error(), - }, err) - }) + req := httptest.NewRequest(http.MethodGet, "/xxx/transactions", nil) + rec := httptest.NewRecorder() + req.URL.RawQuery = testCase.queryParams.Encode() - t.Run("invalid end time", func(t *testing.T) { - rsp := internal.CountTransactions(api, url.Values{ - controllers.QueryKeyEndTime: []string{"invalid"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode) + router.ServeHTTP(rec, req) + require.Equal(t, testCase.expectStatusCode, rec.Code) + if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { + cursor := DecodeCursorResponse[core.ExpandedTransaction](t, rec.Body) + require.Equal(t, expectedCursor, *cursor) + } else { err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: controllers.ErrInvalidEndTime.Error(), - }, err) - }) + Decode(t, rec.Body, &err) + require.EqualValues(t, testCase.expectedErrorCode, err.ErrorCode) + } }) - }) + } } -func TestGetTransactionsWithPageSize(t *testing.T) { - internal.RunTest(t, func(api chi.Router, driver storage.Driver) { - now := core.Now().UTC() - store := internal.GetLedgerStore(t, driver, context.Background()) +func TestCountTransactions(t *testing.T) { + t.Parallel() - _, err := store.Initialize(context.Background()) - require.NoError(t, err) + type testCase struct { + name string + queryParams url.Values + expectQuery storage.TransactionsQuery + expectStatusCode int + expectedErrorCode string + } + now := core.Now() - //TODO(gfyrag): Refine tests, we don't need to insert 3000 tx to test a behavior - for i := 0; i < 3*controllers.MaxPageSize; i++ { - tx := core.ExpandedTransaction{ - Transaction: core.Transaction{ - ID: uint64(i), - TransactionData: core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Destination: fmt.Sprintf("account:%d", i), - Amount: big.NewInt(1000), - Asset: "USD", - }, - }, - Timestamp: now, - }, - }, + testCases := []testCase{ + { + name: "nominal", + expectQuery: *storage.NewTransactionsQuery(), + }, + { + name: "using metadata", + queryParams: url.Values{ + "metadata[roles]": []string{"admin"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithMetadataFilter(map[string]string{ + "roles": "admin", + }), + }, + { + name: "using nested metadata", + queryParams: url.Values{ + "metadata[a.nested.key]": []string{"hello"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithMetadataFilter(map[string]string{ + "a.nested.key": "hello", + }), + }, + { + name: "using startTime", + queryParams: url.Values{ + "startTime": []string{now.Format(core.DateFormat)}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithStartTimeFilter(now), + }, + { + name: "using invalid startTime", + queryParams: url.Values{ + "startTime": []string{"xxx"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "using endTime", + queryParams: url.Values{ + "endTime": []string{now.Format(core.DateFormat)}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithEndTimeFilter(now), + }, + { + name: "using invalid endTime", + queryParams: url.Values{ + "endTime": []string{"xxx"}, + }, + expectStatusCode: http.StatusBadRequest, + expectedErrorCode: apierrors.ErrValidation, + }, + { + name: "using account", + queryParams: url.Values{ + "account": []string{"xxx"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithAccountFilter("xxx"), + }, + { + name: "using reference", + queryParams: url.Values{ + "reference": []string{"xxx"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithReferenceFilter("xxx"), + }, + { + name: "using destination", + queryParams: url.Values{ + "destination": []string{"xxx"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithDestinationFilter("xxx"), + }, + { + name: "using source", + queryParams: url.Values{ + "source": []string{"xxx"}, + }, + expectQuery: *storage.NewTransactionsQuery(). + WithSourceFilter("xxx"), + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + + if testCase.expectStatusCode == 0 { + // TODO(gfyrag): Change status code to 204 + testCase.expectStatusCode = http.StatusOK } - require.NoError(t, store.InsertTransactions(context.Background(), tx)) - } - t.Run("invalid page size", func(t *testing.T) { - rsp := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyPageSize: []string{"nan"}, - }) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) + backend, mockLedger := newTestingBackend(t) + if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { + mockLedger.EXPECT(). + CountTransactions(gomock.Any(), testCase.expectQuery). + Return(uint64(10), nil) + } - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: controllers.ErrInvalidPageSize.Error(), - }, err) - }) - t.Run("page size over maximum", func(t *testing.T) { - httpResponse := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyPageSize: []string{fmt.Sprintf("%d", 2*controllers.MaxPageSize)}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) + router := routes.NewRouter(backend, nil, nil) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, httpResponse.Body) - require.Len(t, cursor.Data, controllers.MaxPageSize) - require.Equal(t, cursor.PageSize, controllers.MaxPageSize) - require.NotEmpty(t, cursor.Next) - require.True(t, cursor.HasMore) - }) - t.Run("with page size greater than max count", func(t *testing.T) { - httpResponse := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyPageSize: []string{fmt.Sprintf("%d", controllers.MaxPageSize)}, - "after": []string{fmt.Sprintf("%d", controllers.MaxPageSize-100)}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) + req := httptest.NewRequest(http.MethodHead, "/xxx/transactions", nil) + rec := httptest.NewRecorder() + req.URL.RawQuery = testCase.queryParams.Encode() - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, httpResponse.Body) - require.Len(t, cursor.Data, controllers.MaxPageSize-100) - require.Equal(t, cursor.PageSize, controllers.MaxPageSize) - require.Empty(t, cursor.Next) - require.False(t, cursor.HasMore) - }) - t.Run("with page size lower than max count", func(t *testing.T) { - httpResponse := internal.GetTransactions(api, url.Values{ - controllers.QueryKeyPageSize: []string{fmt.Sprintf("%d", controllers.MaxPageSize/10)}, - }) - require.Equal(t, http.StatusOK, httpResponse.Result().StatusCode, httpResponse.Body.String()) + router.ServeHTTP(rec, req) - cursor := internal.DecodeCursorResponse[core.ExpandedTransaction](t, httpResponse.Body) - require.Len(t, cursor.Data, controllers.MaxPageSize/10) - require.Equal(t, cursor.PageSize, controllers.MaxPageSize/10) - require.NotEmpty(t, cursor.Next) - require.True(t, cursor.HasMore) + require.Equal(t, testCase.expectStatusCode, rec.Code) + if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { + require.Equal(t, "10", rec.Header().Get("Count")) + } else { + err := sharedapi.ErrorResponse{} + Decode(t, rec.Body, &err) + require.EqualValues(t, testCase.expectedErrorCode, err.ErrorCode) + } }) - }) + } } func TestRevertTransaction(t *testing.T) { - internal.RunTest(t, func(api chi.Router, driver storage.Driver) { - store, _, err := driver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) - _, err = store.Initialize(context.Background()) - require.NoError(t, err) + expectedTx := core.ExpandTransaction( + core.NewTransaction().WithPostings( + core.NewPosting("world", "bank", "USD", big.NewInt(100)), + ), + nil, + ) - tx1 := core.NewTransaction(). - WithPostings(core.NewPosting("world", "alice", "USD", big.NewInt(100))). - WithReference("ref:23434656"). - WithMetadata(core.Metadata{ - "foo1": "bar1", - }). - WithTimestamp(core.Now().Add(-3 * time.Minute)) - require.NoError(t, store.InsertTransactions(context.Background(), core.ExpandTransactionFromEmptyPreCommitVolumes(tx1))) - log := core.NewTransactionLog(tx1, nil) - require.NoError(t, store.AppendLog(context.Background(), &log)) + backend, mockLedger := newTestingBackend(t) + mockLedger. + EXPECT(). + RevertTransaction(gomock.Any(), uint64(0)). + Return(&expectedTx, nil) - tx2 := core.NewTransaction(). - WithPostings(core.NewPosting("world", "bob", "USD", big.NewInt(100))). - WithReference("ref:534646"). - WithMetadata(core.Metadata{ - "foo2": "bar2", - }). - WithID(1). - WithTimestamp(core.Now().Add(-2 * time.Minute)) - require.NoError(t, store.InsertTransactions(context.Background(), core.ExpandTransactionFromEmptyPreCommitVolumes(tx2))) - log2 := core.NewTransactionLog(tx2, nil) - require.NoError(t, store.AppendLog(context.Background(), &log2)) + router := routes.NewRouter(backend, nil, nil) - tx3 := core.NewTransaction(). - WithPostings(core.NewPosting("alice", "bob", "USD", big.NewInt(3))). - WithMetadata(core.Metadata{ - "foo2": "bar2", - }). - WithID(2). - WithTimestamp(core.Now().Add(-time.Minute)) - require.NoError(t, store.InsertTransactions(context.Background(), core.ExpandTransactionFromEmptyPreCommitVolumes(tx3))) - log3 := core.NewTransactionLog(tx3, nil) - require.NoError(t, store.AppendLog(context.Background(), &log3)) + req := httptest.NewRequest(http.MethodPost, "/xxx/transactions/0/revert", nil) + rec := httptest.NewRecorder() - require.NoError(t, store.EnsureAccountExists(context.Background(), "world")) - require.NoError(t, store.EnsureAccountExists(context.Background(), "bob")) - require.NoError(t, store.EnsureAccountExists(context.Background(), "alice")) - require.NoError(t, store.UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ - "world": { - "USD": core.NewEmptyVolumes().WithOutput(big.NewInt(200)), - }, - "alice": { - "USD": core.NewEmptyVolumes().WithInput(big.NewInt(100)).WithOutput(big.NewInt(3)), - }, - "bob": { - "USD": core.NewEmptyVolumes().WithInput(big.NewInt(103)), - }, - })) - - t.Run("first revert should succeed", func(t *testing.T) { - rsp := internal.RevertTransaction(api, 2) - require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - - res, _ := internal.DecodeSingleResponse[core.ExpandedTransaction](t, rsp.Body) - require.EqualValues(t, 3, res.ID) - require.Equal(t, core.Metadata{ - core.RevertMetadataSpecKey(): "2", - }, res.Metadata) - }) - - t.Run("transaction not found", func(t *testing.T) { - rsp := internal.RevertTransaction(api, uint64(42)) - require.Equal(t, http.StatusNotFound, rsp.Result().StatusCode, rsp.Body.String()) - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrNotFound, - ErrorMessage: "transaction 42 not found", - }, err) - }) - - //TODO(gfyrag): tests MUST not depends on previous tests - //use a table driven test - t.Run("second revert should fail", func(t *testing.T) { - require.NoError(t, store.UpdateTransactionMetadata(context.Background(), 2, core.RevertedMetadata(3))) - - rsp := internal.RevertTransaction(api, 2) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "transaction 2 already reverted", - }, err) - }) - - t.Run("invalid transaction ID format", func(t *testing.T) { - rsp := internal.NewPostOnLedger(t, api, "/transactions/invalid/revert", nil) - require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String()) - - err := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &err) - require.EqualValues(t, sharedapi.ErrorResponse{ - ErrorCode: apierrors.ErrValidation, - ErrorMessage: "invalid transaction ID", - }, err) - }) - }) -} - -func TestPostTransactionsScriptConflict(t *testing.T) { - internal.RunTest(t, func(api chi.Router, driver storage.Driver) { - store, _, err := driver.GetLedgerStore(context.Background(), internal.TestingLedger, true) - require.NoError(t, err) - _, err = store.Initialize(context.Background()) - require.NoError(t, err) - log := core.NewTransactionLog( - core.NewTransaction(). - WithPostings(core.NewPosting("world", "centralbank", "COIN", big.NewInt(100))). - WithReference("1234"), - nil, - ) - require.NoError(t, store.AppendLog(context.Background(), &log)) - rsp := internal.PostTransaction(t, api, controllers.PostTransactionRequest{ - Script: core.Script{ - Plain: ` - send [COIN 100] ( - source = @world - destination = @centralbank - )`, - }, - Reference: "1234", - }, false) + router.ServeHTTP(rec, req) - require.Equal(t, http.StatusConflict, rsp.Result().StatusCode) - actualErr := sharedapi.ErrorResponse{} - internal.Decode(t, rsp.Body, &actualErr) - require.Equal(t, apierrors.ErrConflict, actualErr.ErrorCode) - }) + // TODO(gfyrag): Change to 201 + require.Equal(t, http.StatusOK, rec.Code) + tx, ok := DecodeSingleResponse[core.ExpandedTransaction](t, rec.Body) + require.True(t, ok) + require.Equal(t, expectedTx, tx) } diff --git a/pkg/api/controllers/utils_test.go b/pkg/api/controllers/utils_test.go new file mode 100644 index 000000000..142433ef5 --- /dev/null +++ b/pkg/api/controllers/utils_test.go @@ -0,0 +1,57 @@ +package controllers_test + +import ( + "bytes" + "encoding/json" + "io" + "testing" + + sharedapi "github.com/formancehq/stack/libs/go-libs/api" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func Encode(t *testing.T, v interface{}) []byte { + data, err := json.Marshal(v) + assert.NoError(t, err) + return data +} + +func Buffer(t *testing.T, v interface{}) *bytes.Buffer { + return bytes.NewBuffer(Encode(t, v)) +} + +func Decode(t *testing.T, reader io.Reader, v interface{}) bool { + err := json.NewDecoder(reader).Decode(v) + return assert.NoError(t, err) +} + +func DecodeSingleResponse[T any](t *testing.T, reader io.Reader) (T, bool) { + res := sharedapi.BaseResponse[T]{} + if !Decode(t, reader, &res) { + var zero T + return zero, false + } + return *res.Data, true +} + +func DecodeCursorResponse[T any](t *testing.T, reader io.Reader) *sharedapi.Cursor[T] { + res := sharedapi.BaseResponse[T]{} + Decode(t, reader, &res) + return res.Cursor +} + +func newTestingBackend(t *testing.T) (*MockBackend, *MockLedger) { + ctrl := gomock.NewController(t) + mockLedger := NewMockLedger(ctrl) + backend := NewMockBackend(ctrl) + backend. + EXPECT(). + GetLedger(gomock.Any(), gomock.Any()). + MinTimes(0). + Return(mockLedger, nil) + t.Cleanup(func() { + ctrl.Finish() + }) + return backend, mockLedger +} diff --git a/pkg/api/internal/testing.go b/pkg/api/internal/testing.go deleted file mode 100644 index 4d05c88e4..000000000 --- a/pkg/api/internal/testing.go +++ /dev/null @@ -1,221 +0,0 @@ -package internal - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/formancehq/ledger/pkg/api/controllers" - "github.com/formancehq/ledger/pkg/api/routes" - "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/ledger" - "github.com/formancehq/ledger/pkg/ledger/lock" - "github.com/formancehq/ledger/pkg/ledger/monitor" - "github.com/formancehq/ledger/pkg/ledgertesting" - "github.com/formancehq/ledger/pkg/storage" - sharedapi "github.com/formancehq/stack/libs/go-libs/api" - "github.com/formancehq/stack/libs/go-libs/health" - "github.com/formancehq/stack/libs/go-libs/logging" - "github.com/go-chi/chi/v5" - "github.com/pborman/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var TestingLedger string - -func Encode(t *testing.T, v interface{}) []byte { - data, err := json.Marshal(v) - assert.NoError(t, err) - return data -} - -func Buffer(t *testing.T, v interface{}) *bytes.Buffer { - return bytes.NewBuffer(Encode(t, v)) -} - -func Decode(t *testing.T, reader io.Reader, v interface{}) bool { - err := json.NewDecoder(reader).Decode(v) - return assert.NoError(t, err) -} - -func DecodeSingleResponse[T any](t *testing.T, reader io.Reader) (T, bool) { - res := sharedapi.BaseResponse[T]{} - if !Decode(t, reader, &res) { - var zero T - return zero, false - } - return *res.Data, true -} - -func DecodeCursorResponse[T any](t *testing.T, reader io.Reader) *sharedapi.Cursor[T] { - res := sharedapi.BaseResponse[T]{} - Decode(t, reader, &res) - return res.Cursor -} - -func NewRequest(method, path string, body io.Reader) (*http.Request, *httptest.ResponseRecorder) { - rec := httptest.NewRecorder() - req := httptest.NewRequest(method, path, body) - req.Header.Set("Content-Type", "application/json") - - return req, rec -} - -func PostTransaction(t *testing.T, handler http.Handler, payload controllers.PostTransactionRequest, preview bool) *httptest.ResponseRecorder { - path := fmt.Sprintf("/%s/transactions", TestingLedger) - if preview { - path += "?preview=true" - } - req, rec := NewRequest(http.MethodPost, path, Buffer(t, payload)) - handler.ServeHTTP(rec, req) - return rec -} - -func PostTransactionMetadata(t *testing.T, handler http.Handler, id uint64, m core.Metadata) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodPost, fmt.Sprintf("/%s/transactions/%d/metadata", TestingLedger, id), Buffer(t, m)) - handler.ServeHTTP(rec, req) - return rec -} - -func CountTransactions(handler http.Handler, query url.Values) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodHead, fmt.Sprintf("/%s/transactions", TestingLedger), nil) - req.URL.RawQuery = query.Encode() - handler.ServeHTTP(rec, req) - return rec -} - -func GetTransactions(handler http.Handler, query url.Values) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodGet, fmt.Sprintf("/%s/transactions", TestingLedger), nil) - req.URL.RawQuery = query.Encode() - handler.ServeHTTP(rec, req) - return rec -} - -func GetTransaction(handler http.Handler, id uint64) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodGet, fmt.Sprintf("/%s/transactions/%d", TestingLedger, id), nil) - handler.ServeHTTP(rec, req) - return rec -} - -func RevertTransaction(handler http.Handler, id uint64) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodPost, fmt.Sprintf("/"+TestingLedger+"/transactions/%d/revert", id), nil) - handler.ServeHTTP(rec, req) - return rec -} - -func CountAccounts(handler http.Handler, query url.Values) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodHead, fmt.Sprintf("/%s/accounts", TestingLedger), nil) - req.URL.RawQuery = query.Encode() - handler.ServeHTTP(rec, req) - return rec -} - -func GetAccounts(handler http.Handler, query url.Values) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodGet, fmt.Sprintf("/%s/accounts", TestingLedger), nil) - req.URL.RawQuery = query.Encode() - handler.ServeHTTP(rec, req) - return rec -} - -func GetBalances(handler http.Handler, query url.Values) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodGet, fmt.Sprintf("/%s/balances", TestingLedger), nil) - req.URL.RawQuery = query.Encode() - handler.ServeHTTP(rec, req) - return rec -} - -func GetBalancesAggregated(handler http.Handler, query url.Values) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodGet, fmt.Sprintf("/%s/aggregate/balances", TestingLedger), nil) - req.URL.RawQuery = query.Encode() - handler.ServeHTTP(rec, req) - return rec -} - -func GetAccount(handler http.Handler, addr string) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodGet, fmt.Sprintf("/%s/accounts/%s", TestingLedger, addr), nil) - handler.ServeHTTP(rec, req) - return rec -} - -func PostAccountMetadata(t *testing.T, handler http.Handler, addr string, m core.Metadata) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodPost, fmt.Sprintf("/%s/accounts/%s/metadata", TestingLedger, addr), Buffer(t, m)) - handler.ServeHTTP(rec, req) - return rec -} - -func NewRequestOnLedger(t *testing.T, handler http.Handler, path string, body any) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodPost, fmt.Sprintf("/%s%s", TestingLedger, path), Buffer(t, body)) - handler.ServeHTTP(rec, req) - return rec -} - -func NewGetOnLedger(handler http.Handler, path string) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodGet, fmt.Sprintf("/%s%s", TestingLedger, path), nil) - handler.ServeHTTP(rec, req) - return rec -} - -func NewPostOnLedger(t *testing.T, handler http.Handler, path string, body any) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodPost, fmt.Sprintf("/%s%s", TestingLedger, path), Buffer(t, body)) - handler.ServeHTTP(rec, req) - return rec -} - -func GetLedgerInfo(handler http.Handler) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodGet, fmt.Sprintf("/%s/_info", TestingLedger), nil) - handler.ServeHTTP(rec, req) - return rec -} - -func GetLedgerStats(handler http.Handler) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodGet, fmt.Sprintf("/%s/stats", TestingLedger), nil) - handler.ServeHTTP(rec, req) - return rec -} - -func GetLedgerLogs(handler http.Handler, query url.Values) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodGet, fmt.Sprintf("/%s/logs", TestingLedger), nil) - req.URL.RawQuery = query.Encode() - handler.ServeHTTP(rec, req) - return rec -} - -func GetInfo(handler http.Handler) *httptest.ResponseRecorder { - req, rec := NewRequest(http.MethodGet, "/_info", nil) - handler.ServeHTTP(rec, req) - return rec -} - -func GetLedgerStore(t *testing.T, driver storage.Driver, ctx context.Context) storage.LedgerStore { - store, _, err := driver.GetLedgerStore(ctx, TestingLedger, true) - require.NoError(t, err) - return store -} - -func RunTest(t *testing.T, callback func(api chi.Router, storageDriver storage.Driver)) { - TestingLedger = uuid.New() - t.Parallel() - - storageDriver := ledgertesting.StorageDriver(t) - require.NoError(t, storageDriver.Initialize(context.Background())) - - ledgerStore, _, err := storageDriver.GetLedgerStore(context.Background(), uuid.New(), true) - require.NoError(t, err) - - modified, err := ledgerStore.Initialize(context.Background()) - require.NoError(t, err) - require.True(t, modified) - - resolver := ledger.NewResolver(storageDriver, monitor.NewNoOpMonitor(), lock.NewInMemory(), false) - router := routes.NewRouter(storageDriver, "latest", resolver, - logging.FromContext(context.Background()), &health.HealthController{}) - - callback(router, storageDriver) -} diff --git a/pkg/api/middlewares/ledger_middleware.go b/pkg/api/middlewares/ledger_middleware.go index 6140ec5c2..1ea827fca 100644 --- a/pkg/api/middlewares/ledger_middleware.go +++ b/pkg/api/middlewares/ledger_middleware.go @@ -5,7 +5,6 @@ import ( "github.com/formancehq/ledger/pkg/api/apierrors" "github.com/formancehq/ledger/pkg/api/controllers" - "github.com/formancehq/ledger/pkg/ledger" "github.com/formancehq/ledger/pkg/opentelemetry" "github.com/formancehq/stack/libs/go-libs/logging" "github.com/go-chi/chi/v5" @@ -13,7 +12,7 @@ import ( "go.opentelemetry.io/otel/trace" ) -func LedgerMiddleware(resolver *ledger.Resolver) func(handler http.Handler) http.Handler { +func LedgerMiddleware(resolver controllers.Backend) func(handler http.Handler) http.Handler { return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "ledger") diff --git a/pkg/api/routes/routes.go b/pkg/api/routes/routes.go index 83b0c7bfa..6d25c44df 100644 --- a/pkg/api/routes/routes.go +++ b/pkg/api/routes/routes.go @@ -5,8 +5,6 @@ import ( "github.com/formancehq/ledger/pkg/api/controllers" "github.com/formancehq/ledger/pkg/api/middlewares" - "github.com/formancehq/ledger/pkg/ledger" - "github.com/formancehq/ledger/pkg/storage" "github.com/formancehq/stack/libs/go-libs/health" "github.com/formancehq/stack/libs/go-libs/logging" "github.com/go-chi/chi/v5" @@ -15,8 +13,7 @@ import ( "github.com/riandyrn/otelchi" ) -func NewRouter(storageDriver storage.Driver, version string, resolver *ledger.Resolver, - logger logging.Logger, healthController *health.HealthController) chi.Router { +func NewRouter(backend controllers.Backend, logger logging.Logger, healthController *health.HealthController) chi.Router { router := chi.NewMux() router.Use( @@ -28,22 +25,24 @@ func NewRouter(storageDriver storage.Driver, version string, resolver *ledger.Re }).Handler, func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handler.ServeHTTP(w, r.WithContext( - logging.ContextWithLogger(r.Context(), logger), - )) + if logger != nil { + r = r.WithContext( + logging.ContextWithLogger(r.Context(), logger), + ) + } + handler.ServeHTTP(w, r) }) }, middlewares.Log(), middleware.Recoverer, ) - router.Use() router.Use(middlewares.Log()) router.Get("/_healthcheck", healthController.Check) router.Group(func(router chi.Router) { router.Use(otelchi.Middleware("ledger")) - router.Get("/_info", controllers.GetInfo(storageDriver, version)) + router.Get("/_info", controllers.GetInfo(backend)) router.Route("/{ledger}", func(router chi.Router) { router.Use(func(handler http.Handler) http.Handler { @@ -51,7 +50,7 @@ func NewRouter(storageDriver storage.Driver, version string, resolver *ledger.Re handler.ServeHTTP(w, r) }) }) - router.Use(middlewares.LedgerMiddleware(resolver)) + router.Use(middlewares.LedgerMiddleware(backend)) // LedgerController router.Get("/_info", controllers.GetLedgerInfo) diff --git a/pkg/core/account.go b/pkg/core/account.go index 41c4f237f..d6ca889d0 100644 --- a/pkg/core/account.go +++ b/pkg/core/account.go @@ -31,6 +31,16 @@ type AccountWithVolumes struct { Volumes AssetsVolumes `json:"volumes"` } +func NewAccountWithVolumes(address string) *AccountWithVolumes { + return &AccountWithVolumes{ + Account: Account{ + Address: address, + Metadata: Metadata{}, + }, + Volumes: map[string]Volumes{}, + } +} + func (v AccountWithVolumes) MarshalJSON() ([]byte, error) { type aux AccountWithVolumes return json.Marshal(struct { diff --git a/pkg/core/log.go b/pkg/core/log.go index ca9cb17f0..4ddc2d5be 100644 --- a/pkg/core/log.go +++ b/pkg/core/log.go @@ -54,8 +54,8 @@ func (l Log) WithReference(reference string) Log { } type NewTransactionLogPayload struct { - Transaction Transaction - AccountMetadata map[string]Metadata + Transaction Transaction `json:"transaction"` + AccountMetadata map[string]Metadata `json:"accountMetadata"` } func NewTransactionLogWithDate(tx Transaction, accountMetadata map[string]Metadata, time Time) Log { diff --git a/pkg/core/numscript.go b/pkg/core/numscript.go index 67df26603..4041ad176 100644 --- a/pkg/core/numscript.go +++ b/pkg/core/numscript.go @@ -103,6 +103,10 @@ func TxToScriptData(txData TransactionData) RunScript { vars[v.name] = v.jsonVal } + if txData.Metadata == nil { + txData.Metadata = Metadata{} + } + return RunScript{ Script: Script{ Plain: sb.String(), diff --git a/pkg/ledger/ledger.go b/pkg/ledger/ledger.go index 5db912e05..056a8275c 100644 --- a/pkg/ledger/ledger.go +++ b/pkg/ledger/ledger.go @@ -55,7 +55,7 @@ func (l *Ledger) GetLedgerStore() storage.LedgerStore { } func (l *Ledger) writeLog(ctx context.Context, logHolder *core.LogHolder) error { - l.queryWorker.QueueLog(ctx, logHolder, l.store) + l.queryWorker.QueueLog(logHolder) // Wait for CQRS ingestion // TODO(polo/gfyrag): add possiblity to disable this via request param select { @@ -174,6 +174,7 @@ func (l *Ledger) SaveMeta(ctx context.Context, targetType string, targetID inter at := core.Now() var ( + err error log core.Log ) switch targetType { @@ -213,21 +214,22 @@ func (l *Ledger) SaveMeta(ctx context.Context, targetType string, targetID inter default: return errorsutil.NewError(ErrValidation, errors.Errorf("unknown target type '%s'", targetType)) } - - err := l.store.AppendLog(ctx, &log) if err != nil { - return errors.Wrap(err, "append log") + return err } + err = l.store.AppendLog(ctx, &log) logHolder := core.NewLogHolder(&log) - if err := l.writeLog(ctx, logHolder); err != nil { - return errors.Wrap(err, "write log") + if err == nil { + if err := l.writeLog(ctx, logHolder); err != nil { + return err + } } - return nil + return err } -func (l *Ledger) GetLogs(ctx context.Context, q *storage.LogsQuery) (api.Cursor[core.Log], error) { - logs, err := l.store.GetLogs(ctx, q) +func (l *Ledger) GetLogs(ctx context.Context, q storage.LogsQuery) (api.Cursor[core.Log], error) { + logs, err := l.store.GetLogs(ctx, &q) return logs, errors.Wrap(err, "getting logs") } diff --git a/pkg/ledger/ledger_test.go b/pkg/ledger/ledger_test.go deleted file mode 100644 index 13ddd4424..000000000 --- a/pkg/ledger/ledger_test.go +++ /dev/null @@ -1,214 +0,0 @@ -package ledger - -import ( - "context" - "fmt" - "math/big" - "sync" - "testing" - - "github.com/formancehq/ledger/pkg/core" - "github.com/google/uuid" - "github.com/stretchr/testify/require" -) - -func TestAccountMetadata(t *testing.T) { - runOnLedger(t, func(l *Ledger) { - - err := l.SaveMeta(context.Background(), core.MetaTargetTypeAccount, "users:001", core.Metadata{ - "a random metadata": "old value", - }) - require.NoError(t, err) - - err = l.SaveMeta(context.Background(), core.MetaTargetTypeAccount, "users:001", core.Metadata{ - "a random metadata": "new value", - }) - require.NoError(t, err) - - acc, err := l.dbCache.GetAccountWithVolumes(context.Background(), "users:001") - require.NoError(t, err) - - meta, ok := acc.Metadata["a random metadata"] - require.True(t, ok) - - require.Equalf(t, meta, "new value", - "metadata entry did not match in get: expected \"new value\", got %v", meta) - - // We have to create at least one transaction to retrieve an account from GetAccounts store method - _, err = l.CreateTransaction(context.Background(), false, core.TxToScriptData(core.TransactionData{ - Postings: core.Postings{ - { - Source: "world", - Amount: big.NewInt(100), - Asset: "USD", - Destination: "users:001", - }, - }, - })) - require.NoError(t, err) - - acc, err = l.dbCache.GetAccountWithVolumes(context.Background(), "users:001") - require.NoError(t, err) - require.NotNil(t, acc) - - meta, ok = acc.Metadata["a random metadata"] - require.True(t, ok) - require.Equalf(t, meta, "new value", - "metadata entry did not match in find: expected \"new value\", got %v", meta) - }) -} - -func TestTransactionMetadata(t *testing.T) { - runOnLedger(t, func(l *Ledger) { - err := l.SaveMeta(context.Background(), core.MetaTargetTypeTransaction, uint64(0), core.Metadata{ - "a random metadata": "old value", - }) - require.NoError(t, err) - }) -} - -func TestRevertTransaction(t *testing.T) { - runOnLedger(t, func(l *Ledger) { - tx := core.Transaction{ - TransactionData: core.TransactionData{ - Reference: "foo", - Postings: []core.Posting{ - core.NewPosting("world", "payments:001", "COIN", big.NewInt(100)), - }, - }, - } - expandedTx := core.ExpandedTransaction{ - Transaction: tx, - PreCommitVolumes: map[string]core.AssetsVolumes{ - "world": { - "COIN": core.NewEmptyVolumes().WithOutput(big.NewInt(10)), - }, - "payments:001": { - "COIN": core.NewEmptyVolumes(), - }, - }, - PostCommitVolumes: map[string]core.AssetsVolumes{ - "world": { - "COIN": core.NewEmptyVolumes().WithOutput(big.NewInt(110)), - }, - "payments:001": { - "COIN": core.NewEmptyVolumes().WithInput(big.NewInt(100)), - }, - }, - } - - require.NoError(t, l.GetLedgerStore().InsertTransactions(context.Background(), expandedTx)) - require.NoError(t, l.GetLedgerStore().EnsureAccountExists(context.Background(), "payments:001")) - require.NoError(t, l.GetLedgerStore().UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ - "payments:001": { - "COIN": core.NewEmptyVolumes().WithInput(big.NewInt(110)), - }, - "world": { - "COIN": core.NewEmptyVolumes().WithOutput(big.NewInt(110)), - }, - })) - - revertTx, err := l.RevertTransaction(context.Background(), tx.ID) - require.NoError(t, err) - - require.Equal(t, core.Postings{ - { - Source: "payments:001", - Destination: "world", - Amount: big.NewInt(100), - Asset: "COIN", - }, - }, revertTx.TransactionData.Postings) - - require.EqualValues(t, fmt.Sprintf("%d", tx.ID), revertTx.Metadata[core.RevertMetadataSpecKey()]) - require.Equal(t, revertTx.Timestamp, l.runner.GetState().GetMoreRecentTransactionDate()) - - account, err := l.dbCache.GetAccountWithVolumes(context.Background(), "payments:001") - require.NoError(t, err) - require.Equal(t, core.AccountWithVolumes{ - Account: core.Account{ - Address: "payments:001", - Metadata: core.Metadata{}, - }, - Volumes: core.AssetsVolumes{ - "COIN": core.NewEmptyVolumes(). - WithInput(big.NewInt(110)). - WithOutput(tx.Postings[0].Amount), - }, - }, *account) - - rawLogs, err := l.GetLedgerStore().ReadLogsStartingFromID(context.Background(), 0) - require.NoError(t, err) - require.Len(t, rawLogs, 1) - require.Equal(t, core.NewRevertedTransactionLog(revertTx.Timestamp, tx.ID, revertTx.Transaction). - WithReference("revert_"+tx.Reference). - ComputeHash(nil), rawLogs[0]) - }) -} - -func TestVeryBigTransaction(t *testing.T) { - runOnLedger(t, func(l *Ledger) { - amount, ok := new(big.Int).SetString( - "199999999999999999992919191919192929292939847477171818284637291884661818183647392936472918836161728274766266161728493736383838", 10) - require.True(t, ok) - - _, err := l.CreateTransaction(context.Background(), false, - core.TxToScriptData(core.TransactionData{ - Postings: []core.Posting{{ - Source: "world", - Destination: "bank", - Asset: "ETH/18", - Amount: amount, - }}, - })) - require.NoError(t, err) - }) -} - -func BenchmarkSequentialWrites(b *testing.B) { - ledgerName := uuid.NewString() - resolver := newResolver(b) - - ledger, err := resolver.GetLedger(context.TODO(), ledgerName) - require.NoError(b, err) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := ledger.CreateTransaction(context.Background(), false, core.RunScript{ - Script: core.Script{ - Plain: `send [USD/2 100] ( - source = @world - destination = @bank - )`, - }, - }) - require.NoError(b, err) - } -} - -func BenchmarkParallelWrites(b *testing.B) { - resolver := newResolver(b) - - ledger, err := resolver.GetLedger(context.Background(), uuid.NewString()) - require.NoError(b, err) - - b.ResetTimer() - wg := sync.WaitGroup{} - wg.Add(b.N) - for i := 0; i < b.N; i++ { - go func() { - defer wg.Done() - - _, err := ledger.CreateTransaction(context.Background(), false, core.RunScript{ - Script: core.Script{ - Plain: `send [USD/2 100] ( - source = @world - destination = @bank - )`, - }, - }) - require.NoError(b, err) - }() - } - wg.Wait() -} diff --git a/pkg/ledger/main_test.go b/pkg/ledger/main_test.go deleted file mode 100644 index 29bf6ff83..000000000 --- a/pkg/ledger/main_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package ledger - -import ( - "context" - "os" - "testing" - - "github.com/formancehq/ledger/pkg/ledger/lock" - "github.com/formancehq/ledger/pkg/ledger/monitor" - "github.com/formancehq/ledger/pkg/ledgertesting" - "github.com/formancehq/stack/libs/go-libs/logging" - "github.com/formancehq/stack/libs/go-libs/pgtesting" - "github.com/pborman/uuid" - "github.com/stretchr/testify/require" -) - -func TestMain(t *testing.M) { - if err := pgtesting.CreatePostgresServer(); err != nil { - logging.Error(err) - os.Exit(1) - } - code := t.Run() - if err := pgtesting.DestroyPostgresServer(); err != nil { - logging.Error(err) - } - os.Exit(code) -} - -func newResolver(t interface{ pgtesting.TestingT }) *Resolver { - storageDriver := ledgertesting.StorageDriver(t) - require.NoError(t, storageDriver.Initialize(context.Background())) - - ledgerStore, _, err := storageDriver.GetLedgerStore(context.Background(), uuid.New(), true) - require.NoError(t, err) - - modified, err := ledgerStore.Initialize(context.Background()) - require.NoError(t, err) - require.True(t, modified) - - return NewResolver(storageDriver, monitor.NewNoOpMonitor(), lock.NewInMemory(), false) -} - -func runOnLedger(t interface { - pgtesting.TestingT - Parallel() -}, f func(l *Ledger)) { - t.Parallel() - - ledgerName := uuid.New() - resolver := newResolver(t) - l, err := resolver.GetLedger(context.Background(), ledgerName) - require.NoError(t, err) - defer l.Close(context.Background()) - - f(l) -} diff --git a/pkg/ledger/query/store.go b/pkg/ledger/query/store.go new file mode 100644 index 000000000..14b50ea34 --- /dev/null +++ b/pkg/ledger/query/store.go @@ -0,0 +1,41 @@ +package query + +import ( + "context" + + "github.com/formancehq/ledger/pkg/core" + "github.com/formancehq/ledger/pkg/storage" +) + +type Store interface { + UpdateNextLogID(ctx context.Context, u uint64) error + IsInitialized() bool + GetNextLogID(ctx context.Context) (uint64, error) + ReadLogsStartingFromID(ctx context.Context, id uint64) ([]core.Log, error) + RunInTransaction(ctx context.Context, f func(ctx context.Context, tx Store) error) error + GetAccountWithVolumes(ctx context.Context, address string) (*core.AccountWithVolumes, error) + GetTransaction(ctx context.Context, id uint64) (*core.ExpandedTransaction, error) + UpdateAccountsMetadata(ctx context.Context, update []core.Account) error + InsertTransactions(ctx context.Context, insert ...core.ExpandedTransaction) error + UpdateTransactionsMetadata(ctx context.Context, update ...core.TransactionWithMetadata) error + EnsureAccountsExist(ctx context.Context, accounts []string) error + UpdateVolumes(ctx context.Context, update ...core.AccountsAssetsVolumes) error +} + +type defaultStore struct { + storage.LedgerStore +} + +func (d defaultStore) RunInTransaction(ctx context.Context, f func(ctx context.Context, tx Store) error) error { + return d.LedgerStore.RunInTransaction(ctx, func(ctx context.Context, store storage.LedgerStore) error { + return f(ctx, NewDefaultStore(d.LedgerStore)) + }) +} + +var _ Store = (*defaultStore)(nil) + +func NewDefaultStore(underlying storage.LedgerStore) *defaultStore { + return &defaultStore{ + LedgerStore: underlying, + } +} diff --git a/pkg/ledger/query/worker.go b/pkg/ledger/query/worker.go index 1bc6dad66..4822d9107 100644 --- a/pkg/ledger/query/worker.go +++ b/pkg/ledger/query/worker.go @@ -15,7 +15,7 @@ import ( var ( DefaultWorkerConfig = WorkerConfig{ - ChanSize: 100, + ChanSize: 1024, } ) @@ -42,10 +42,16 @@ type Worker struct { releasedJob chan struct{} errorChan chan error stopChan chan chan struct{} + readyChan chan struct{} - store storage.LedgerStore + store Store monitor monitor.Monitor lastProcessedLogID *uint64 + ledgerName string +} + +func (w *Worker) Ready() chan struct{} { + return w.readyChan } func (w *Worker) Run(ctx context.Context) error { @@ -53,6 +59,8 @@ func (w *Worker) Run(ctx context.Context) error { w.ctx = ctx + close(w.readyChan) + for { select { case <-w.ctx.Done(): @@ -113,7 +121,7 @@ func (w *Worker) writeLoop(ctx context.Context) { logging.FromContext(w.ctx).Errorf("CQRS worker error: %s", err) closeLogs(modelsHolder) - // TODO(polo/gfyrag): add indempotency tests + // TODO(polo/gfyrag): add idempotency tests // Return the error to restart the worker w.errorChan <- err return @@ -278,7 +286,12 @@ func (w *Worker) processLogs(ctx context.Context, logs ...core.Log) error { return errors.Wrap(err, "building data") } - if err := w.store.RunInTransaction(ctx, func(ctx context.Context, tx storage.LedgerStore) error { + if err := w.store.RunInTransaction(ctx, func(ctx context.Context, tx Store) error { + if len(logsData.ensureAccountsExist) > 0 { + if err := tx.EnsureAccountsExist(ctx, logsData.ensureAccountsExist); err != nil { + return errors.Wrap(err, "ensuring accounts exist") + } + } if len(logsData.accountsToUpdate) > 0 { if err := tx.UpdateAccountsMetadata(ctx, logsData.accountsToUpdate); err != nil { return errors.Wrap(err, "updating accounts metadata") @@ -297,12 +310,6 @@ func (w *Worker) processLogs(ctx context.Context, logs ...core.Log) error { } } - if len(logsData.ensureAccountsExist) > 0 { - if err := tx.EnsureAccountsExist(ctx, logsData.ensureAccountsExist); err != nil { - return errors.Wrap(err, "ensuring accounts exist") - } - } - if len(logsData.volumesToUpdate) > 0 { return tx.UpdateVolumes(ctx, logsData.volumesToUpdate...) } @@ -330,6 +337,7 @@ func (w *Worker) buildData( volumeAggregator := aggregator.Volumes(w.store) accountsToUpdate := make(map[string]core.Metadata) transactionsToUpdate := make(map[uint64]core.Metadata) + for _, log := range logs { switch log.Type { case core.NewTransactionLogType: @@ -366,9 +374,9 @@ func (w *Worker) buildData( logsData.volumesToUpdate = append(logsData.volumesToUpdate, txVolumeAggregator.PostCommitVolumes) logsData.monitors = append(logsData.monitors, func(ctx context.Context, monitor monitor.Monitor) { - w.monitor.CommittedTransactions(ctx, w.store.Name(), expandedTx) + w.monitor.CommittedTransactions(ctx, w.ledgerName, expandedTx) for account, metadata := range payload.AccountMetadata { - w.monitor.SavedMetadata(ctx, w.store.Name(), core.MetaTargetTypeAccount, account, metadata) + w.monitor.SavedMetadata(ctx, w.ledgerName, core.MetaTargetTypeAccount, account, metadata) } }) @@ -397,7 +405,7 @@ func (w *Worker) buildData( } logsData.monitors = append(logsData.monitors, func(ctx context.Context, monitor monitor.Monitor) { - w.monitor.SavedMetadata(ctx, w.store.Name(), w.store.Name(), fmt.Sprint(setMetadata.TargetID), setMetadata.Metadata) + w.monitor.SavedMetadata(ctx, w.ledgerName, setMetadata.TargetType, fmt.Sprint(setMetadata.TargetID), setMetadata.Metadata) }) case core.RevertedTransactionLogType: @@ -430,7 +438,7 @@ func (w *Worker) buildData( } logsData.monitors = append(logsData.monitors, func(ctx context.Context, monitor monitor.Monitor) { - w.monitor.RevertedTransaction(ctx, w.store.Name(), revertedTx, &expandedTx) + w.monitor.RevertedTransaction(ctx, w.ledgerName, revertedTx, &expandedTx) }) } } @@ -452,14 +460,14 @@ func (w *Worker) buildData( return logsData, nil } -func (w *Worker) QueueLog(ctx context.Context, log *core.LogHolder, store storage.LedgerStore) { +func (w *Worker) QueueLog(log *core.LogHolder) { select { case <-w.ctx.Done(): case w.writeChannel <- log: } } -func NewWorker(config WorkerConfig, store storage.LedgerStore, monitor monitor.Monitor) *Worker { +func NewWorker(config WorkerConfig, store Store, ledgerName string, monitor monitor.Monitor) *Worker { return &Worker{ pending: make([]*core.LogHolder, 0), jobs: make(chan []*core.LogHolder), @@ -467,8 +475,10 @@ func NewWorker(config WorkerConfig, store storage.LedgerStore, monitor monitor.M writeChannel: make(chan *core.LogHolder, config.ChanSize), errorChan: make(chan error, 1), stopChan: make(chan chan struct{}), + readyChan: make(chan struct{}), WorkerConfig: config, store: store, monitor: monitor, + ledgerName: ledgerName, } } diff --git a/pkg/ledger/query/worker_test.go b/pkg/ledger/query/worker_test.go index 521065006..1c6e62dc3 100644 --- a/pkg/ledger/query/worker_test.go +++ b/pkg/ledger/query/worker_test.go @@ -8,69 +8,141 @@ import ( "github.com/formancehq/ledger/pkg/core" "github.com/formancehq/ledger/pkg/ledger/monitor" - "github.com/formancehq/ledger/pkg/ledgertesting" - "github.com/formancehq/ledger/pkg/storage" - "github.com/formancehq/stack/libs/go-libs/pgtesting" - "github.com/pborman/uuid" "github.com/stretchr/testify/require" ) -func TestWorker(t *testing.T) { - t.Parallel() +type mockStore struct { + nextLogID uint64 + accounts map[string]*core.AccountWithVolumes + transactions []*core.ExpandedTransaction +} - require.NoError(t, pgtesting.CreatePostgresServer()) - defer func() { - require.NoError(t, pgtesting.DestroyPostgresServer()) - }() +func (m *mockStore) UpdateAccountsMetadata(ctx context.Context, update []core.Account) error { + for _, account := range update { + persistedAccount, ok := m.accounts[account.Address] + if !ok { + m.accounts[account.Address] = &core.AccountWithVolumes{ + Account: account, + Volumes: map[string]core.Volumes{}, + } + return nil + } + persistedAccount.Metadata = persistedAccount.Metadata.Merge(account.Metadata) + } + return nil +} + +func (m *mockStore) InsertTransactions(ctx context.Context, insert ...core.ExpandedTransaction) error { + for _, transaction := range insert { + m.transactions = append(m.transactions, &transaction) + } + return nil +} + +func (m *mockStore) UpdateTransactionsMetadata(ctx context.Context, update ...core.TransactionWithMetadata) error { + for _, tx := range update { + m.transactions[tx.ID].Metadata = m.transactions[tx.ID].Metadata.Merge(tx.Metadata) + } + return nil +} + +func (m *mockStore) EnsureAccountsExist(ctx context.Context, accounts []string) error { + for _, address := range accounts { + _, ok := m.accounts[address] + if ok { + continue + } + m.accounts[address] = &core.AccountWithVolumes{ + Account: core.Account{ + Address: address, + Metadata: core.Metadata{}, + }, + Volumes: map[string]core.Volumes{}, + } + } + return nil +} + +func (m *mockStore) UpdateVolumes(ctx context.Context, updates ...core.AccountsAssetsVolumes) error { + for _, update := range updates { + for address, volumes := range update { + for asset, assetsVolumes := range volumes { + m.accounts[address].Volumes[asset] = assetsVolumes + } + } + } + return nil +} + +func (m *mockStore) UpdateNextLogID(ctx context.Context, id uint64) error { + m.nextLogID = id + return nil +} + +func (m *mockStore) IsInitialized() bool { + return true +} + +func (m *mockStore) GetNextLogID(ctx context.Context) (uint64, error) { + return m.nextLogID, nil +} + +func (m *mockStore) ReadLogsStartingFromID(ctx context.Context, id uint64) ([]core.Log, error) { + return []core.Log{}, nil +} + +func (m *mockStore) RunInTransaction(ctx context.Context, f func(ctx context.Context, tx Store) error) error { + return f(ctx, m) +} - driver := ledgertesting.StorageDriver(t) - require.NoError(t, driver.Initialize(context.Background())) +func (m *mockStore) GetAccountWithVolumes(ctx context.Context, address string) (*core.AccountWithVolumes, error) { + account, ok := m.accounts[address] + if !ok { + return &core.AccountWithVolumes{ + Account: core.Account{ + Address: address, + Metadata: core.Metadata{}, + }, + Volumes: map[string]core.Volumes{}, + }, nil + } + return account, nil +} - ledgerStore, _, err := driver.GetLedgerStore(context.Background(), uuid.New(), true) - require.NoError(t, err) +func (m *mockStore) GetTransaction(ctx context.Context, id uint64) (*core.ExpandedTransaction, error) { + return m.transactions[id], nil +} - modified, err := ledgerStore.Initialize(context.Background()) - require.NoError(t, err) - require.True(t, modified) +var _ Store = (*mockStore)(nil) + +func TestWorker(t *testing.T) { + t.Parallel() + + ledgerStore := &mockStore{ + accounts: map[string]*core.AccountWithVolumes{}, + } worker := NewWorker(WorkerConfig{ ChanSize: 1024, - }, ledgerStore, monitor.NewNoOpMonitor()) + }, ledgerStore, "default", monitor.NewNoOpMonitor()) go func() { require.NoError(t, worker.Run(context.Background())) }() defer func() { require.NoError(t, worker.Stop(context.Background())) }() + <-worker.Ready() var ( now = core.Now() ) - tx0 := core.Transaction{ - TransactionData: core.TransactionData{ - Postings: []core.Posting{{ - Source: "world", - Destination: "bank", - Amount: big.NewInt(100), - Asset: "USD/2", - }}, - Timestamp: now, - }, - ID: 0, - } - tx1 := core.Transaction{ - TransactionData: core.TransactionData{ - Postings: []core.Posting{{ - Source: "bank", - Destination: "user:1", - Amount: big.NewInt(10), - Asset: "USD/2", - }}, - Timestamp: now, - }, - ID: 1, - } + tx0 := core.NewTransaction().WithPostings( + core.NewPosting("world", "bank", "USD/2", big.NewInt(100)), + ) + tx1 := core.NewTransaction().WithPostings( + core.NewPosting("bank", "user:1", "USD/2", big.NewInt(10)), + ).WithID(1) appliedMetadataOnTX1 := core.Metadata{ "paymentID": "1234", @@ -79,10 +151,6 @@ func TestWorker(t *testing.T) { "category": "gold", } - nextLogID, err := ledgerStore.GetNextLogID(context.Background()) - require.True(t, storage.IsNotFoundError(err)) - require.Equal(t, uint64(0), nextLogID) - logs := []core.Log{ core.NewTransactionLog(tx0, nil), core.NewTransactionLog(tx1, nil), @@ -102,10 +170,10 @@ func TestWorker(t *testing.T) { Metadata: appliedMetadataOnAccount, }), } - for _, log := range logs { + for i, log := range logs { + log.ID = uint64(i) logHolder := core.NewLogHolder(&log) - require.NoError(t, ledgerStore.AppendLog(context.Background(), &log)) - worker.QueueLog(context.Background(), logHolder, ledgerStore) + worker.QueueLog(logHolder) <-logHolder.Ingested } require.Eventually(t, func() bool { @@ -114,23 +182,17 @@ func TestWorker(t *testing.T) { return nextLogID == uint64(len(logs)) }, time.Second, 100*time.Millisecond) - count, err := ledgerStore.CountTransactions(context.Background(), *storage.NewTransactionsQuery()) - require.NoError(t, err) - require.EqualValues(t, 2, count) - - count, err = ledgerStore.CountAccounts(context.Background(), *storage.NewAccountsQuery()) - require.NoError(t, err) - require.EqualValues(t, 4, count) + require.EqualValues(t, 2, len(ledgerStore.transactions)) + require.EqualValues(t, 4, len(ledgerStore.accounts)) - account, err := ledgerStore.GetAccountWithVolumes(context.Background(), "bank") - require.NoError(t, err) + account := ledgerStore.accounts["bank"] + require.NotNil(t, account) require.NotEmpty(t, account.Volumes) require.EqualValues(t, 100, account.Volumes["USD/2"].Input.Uint64()) require.EqualValues(t, 10, account.Volumes["USD/2"].Output.Uint64()) - tx1FromDatabase, err := ledgerStore.GetTransaction(context.Background(), 1) + tx1FromDatabase := ledgerStore.transactions[1] tx1.Metadata = appliedMetadataOnTX1 - require.NoError(t, err) require.Equal(t, core.ExpandedTransaction{ Transaction: tx1, PreCommitVolumes: map[string]core.AssetsVolumes{ @@ -163,8 +225,7 @@ func TestWorker(t *testing.T) { }, }, *tx1FromDatabase) - accountWithVolumes, err := ledgerStore.GetAccountWithVolumes(context.Background(), "bank") - require.NoError(t, err) + accountWithVolumes := ledgerStore.accounts["bank"] require.Equal(t, &core.AccountWithVolumes{ Account: core.Account{ Address: "bank", @@ -178,8 +239,7 @@ func TestWorker(t *testing.T) { }, }, accountWithVolumes) - accountWithVolumes, err = ledgerStore.GetAccountWithVolumes(context.Background(), "another:account") - require.NoError(t, err) + accountWithVolumes = ledgerStore.accounts["another:account"] require.Equal(t, &core.AccountWithVolumes{ Account: core.Account{ Address: "another:account", diff --git a/pkg/ledger/resolver.go b/pkg/ledger/resolver.go index 616c9c61b..39bc1f625 100644 --- a/pkg/ledger/resolver.go +++ b/pkg/ledger/resolver.go @@ -61,14 +61,12 @@ func (r *Resolver) GetLedger(ctx context.Context, name string) (*Ledger, error) } cache := cache.New(store) - runner, err := runner.New(store, r.locker, cache, r.compiler, r.allowPastTimestamps) + runner, err := runner.New(store, r.locker, cache, r.compiler, name, r.allowPastTimestamps) if err != nil { return nil, errors.Wrap(err, "creating ledger runner") } - queryWorker := query.NewWorker(query.WorkerConfig{ - ChanSize: 1024, - }, store, r.monitor) + queryWorker := query.NewWorker(query.DefaultWorkerConfig, query.NewDefaultStore(store), name, r.monitor) go func() { if err := queryWorker.Run(logging.ContextWithLogger( diff --git a/pkg/ledger/runner/runner.go b/pkg/ledger/runner/runner.go index 948896b6c..afdcabb50 100644 --- a/pkg/ledger/runner/runner.go +++ b/pkg/ledger/runner/runner.go @@ -18,16 +18,29 @@ import ( "github.com/pkg/errors" ) +type Store interface { + AppendLog(ctx context.Context, log *core.Log) error + ReadLastLogWithType(background context.Context, logType ...core.LogType) (*core.Log, error) + ReadLogWithReference(ctx context.Context, reference string) (*core.Log, error) +} + +type Cache interface { + GetAccountWithVolumes(ctx context.Context, address string) (*core.AccountWithVolumes, error) + LockAccounts(ctx context.Context, accounts ...string) (cache.Release, error) + UpdateVolumeWithTX(transaction core.Transaction) +} + type Runner struct { - store storage.LedgerStore + store Store // cache is used to store accounts - cache *cache.Cache + cache Cache // nextTxID store the next transaction id to be used nextTxID *atomic.Uint64 // locker is used to local a set of account - locker lock.Locker - compiler *numscript.Compiler - state *state.State + locker lock.Locker + compiler *numscript.Compiler + state *state.State + ledgerName string } type logComputer func(transaction core.ExpandedTransaction, accountMetadata map[string]core.Metadata) core.Log @@ -95,7 +108,7 @@ func (r *Runner) execute(ctx context.Context, script core.RunScript, logComputer return nil, nil, errors.Wrap(err, "locking accounts") } - unlock, err := r.locker.Lock(ctx, r.store.Name(), involvedAccounts...) + unlock, err := r.locker.Lock(ctx, r.ledgerName, involvedAccounts...) if err != nil { release() return nil, nil, errors.Wrap(err, "locking accounts") @@ -169,7 +182,7 @@ func (r *Runner) GetState() *state.State { return r.state } -func New(store storage.LedgerStore, locker lock.Locker, cache *cache.Cache, compiler *numscript.Compiler, allowPastTimestamps bool) (*Runner, error) { +func New(store Store, locker lock.Locker, cache Cache, compiler *numscript.Compiler, ledgerName string, allowPastTimestamps bool) (*Runner, error) { log, err := store.ReadLastLogWithType(context.Background(), core.NewTransactionLogType, core.RevertedTransactionLogType) if err != nil && !storage.IsNotFoundError(err) { return nil, err @@ -196,11 +209,12 @@ func New(store storage.LedgerStore, locker lock.Locker, cache *cache.Cache, comp nextTxID.Add(1) } return &Runner{ - state: state.New(store, allowPastTimestamps, lastTransactionDate), - store: store, - cache: cache, - locker: locker, - nextTxID: nextTxID, - compiler: compiler, + state: state.New(store, allowPastTimestamps, lastTransactionDate), + store: store, + cache: cache, + locker: locker, + nextTxID: nextTxID, + compiler: compiler, + ledgerName: ledgerName, }, nil } diff --git a/pkg/ledger/runner/runner_test.go b/pkg/ledger/runner/runner_test.go index 2bce4bee7..47026fc66 100644 --- a/pkg/ledger/runner/runner_test.go +++ b/pkg/ledger/runner/runner_test.go @@ -10,13 +10,76 @@ import ( "github.com/formancehq/ledger/pkg/ledger/lock" "github.com/formancehq/ledger/pkg/ledger/numscript" "github.com/formancehq/ledger/pkg/ledger/state" - "github.com/formancehq/ledger/pkg/ledgertesting" - "github.com/formancehq/stack/libs/go-libs/pgtesting" + "github.com/formancehq/ledger/pkg/storage" "github.com/google/uuid" "github.com/pkg/errors" "github.com/stretchr/testify/require" ) +type mockCache struct { + accounts map[string]*core.AccountWithVolumes +} + +func (m *mockCache) GetAccountWithVolumes(ctx context.Context, address string) (*core.AccountWithVolumes, error) { + account, ok := m.accounts[address] + if !ok { + account = core.NewAccountWithVolumes(address) + m.accounts[address] = account + return account, nil + } + return account, nil +} + +func (m *mockCache) LockAccounts(ctx context.Context, accounts ...string) (cache.Release, error) { + return func() {}, nil +} + +func (m *mockCache) UpdateVolumeWithTX(transaction core.Transaction) { + for _, posting := range transaction.Postings { + sourceAccount := m.accounts[posting.Source] + sourceAccountAsset := sourceAccount.Volumes[posting.Asset].CopyWithZerosIfNeeded() + sourceAccountAsset.Output = sourceAccountAsset.Output.Add(sourceAccountAsset.Output, posting.Amount) + sourceAccount.Volumes[posting.Asset] = sourceAccountAsset + destAccount := m.accounts[posting.Destination] + destAccountAsset := destAccount.Volumes[posting.Asset].CopyWithZerosIfNeeded() + destAccountAsset.Input = destAccountAsset.Input.Add(destAccountAsset.Input, posting.Amount) + destAccount.Volumes[posting.Asset] = destAccountAsset + } +} + +var _ Cache = (*mockCache)(nil) + +type mockStore struct { + logs []*core.Log +} + +func (m *mockStore) ReadLastLogWithType(background context.Context, logType ...core.LogType) (*core.Log, error) { + for _, log := range m.logs { + for _, logType := range logType { + if log.Type == logType { + return log, nil + } + } + } + return nil, storage.ErrNotFound +} + +func (m *mockStore) ReadLogWithReference(ctx context.Context, reference string) (*core.Log, error) { + for _, log := range m.logs { + if log.Reference == reference { + return log, nil + } + } + return nil, storage.ErrNotFound +} + +func (m *mockStore) AppendLog(ctx context.Context, log *core.Log) error { + m.logs = append(m.logs, log) + return nil +} + +var _ Store = (*mockStore)(nil) + type testCase struct { name string setup func(t *testing.T, r *Runner) @@ -159,30 +222,21 @@ func TestExecuteScript(t *testing.T) { t.Parallel() now := core.Now() - require.NoError(t, pgtesting.CreatePostgresServer()) - defer func() { - require.NoError(t, pgtesting.DestroyPostgresServer()) - }() - - storageDriver := ledgertesting.StorageDriver(t) - require.NoError(t, storageDriver.Initialize(context.Background())) - for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { ledger := uuid.NewString() - - store, _, err := storageDriver.GetLedgerStore(context.Background(), ledger, true) - require.NoError(t, err) - - _, err = store.Initialize(context.Background()) - require.NoError(t, err) + cache := &mockCache{ + accounts: map[string]*core.AccountWithVolumes{}, + } + store := &mockStore{ + logs: []*core.Log{}, + } compiler := numscript.NewCompiler() - cache := cache.New(store) - runner, err := New(store, lock.NewInMemory(), cache, compiler, false) + runner, err := New(store, lock.NewInMemory(), cache, compiler, ledger, false) require.NoError(t, err) if tc.setup != nil { @@ -206,14 +260,8 @@ func TestExecuteScript(t *testing.T) { tc.expectedTx.Timestamp = now require.Equal(t, tc.expectedTx, *ret) - logs, err := store.ReadLogsStartingFromID(context.Background(), 0) - require.NoError(t, err) - require.Len(t, logs, len(tc.expectedLogs)) + require.Len(t, store.logs, len(tc.expectedLogs)) for ind := range tc.expectedLogs { - var previous *core.Log - if ind > 0 { - previous = &tc.expectedLogs[ind-1] - } expectedLog := tc.expectedLogs[ind] switch v := expectedLog.Data.(type) { case core.NewTransactionLogPayload: @@ -221,7 +269,6 @@ func TestExecuteScript(t *testing.T) { expectedLog.Data = v } expectedLog.Date = now - require.Equal(t, expectedLog.ComputeHash(previous), logs[ind]) } require.Equal(t, tc.expectedTx.Timestamp, runner.state.GetMoreRecentTransactionDate()) diff --git a/pkg/ledger/stats_test.go b/pkg/ledger/stats_test.go deleted file mode 100644 index 45f8fdbca..000000000 --- a/pkg/ledger/stats_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package ledger - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestStats(t *testing.T) { - runOnLedger(t, func(l *Ledger) { - _, err := l.Stats(context.Background()) - assert.NoError(t, err) - }) -} diff --git a/pkg/ledgertesting/storage.go b/pkg/ledgertesting/storage.go deleted file mode 100644 index 7842c8a5c..000000000 --- a/pkg/ledgertesting/storage.go +++ /dev/null @@ -1,44 +0,0 @@ -package ledgertesting - -import ( - "context" - - "github.com/formancehq/ledger/pkg/storage" - "github.com/formancehq/ledger/pkg/storage/sqlstorage" - "github.com/formancehq/ledger/pkg/storage/sqlstorage/schema" - "github.com/formancehq/stack/libs/go-libs/pgtesting" - "github.com/stretchr/testify/require" - "go.uber.org/fx" -) - -func StorageDriver(t pgtesting.TestingT) *sqlstorage.Driver { - pgServer := pgtesting.NewPostgresDatabase(t) - - db, err := sqlstorage.OpenSQLDB(pgServer.ConnString()) - require.NoError(t, err) - - t.Cleanup(func() { - db.Close() - }) - - return sqlstorage.NewDriver("postgres", schema.NewPostgresDB(db)) -} - -func ProvideStorageDriver(t pgtesting.TestingT) fx.Option { - return fx.Provide(func(lc fx.Lifecycle) (storage.Driver, error) { - driver := StorageDriver(t) - lc.Append(fx.Hook{ - OnStart: driver.Initialize, - OnStop: func(ctx context.Context) error { - return driver.Close(ctx) - }, - }) - return driver, nil - }) -} - -func ProvideLedgerStorageDriver(t pgtesting.TestingT) fx.Option { - return fx.Options( - ProvideStorageDriver(t), - ) -} diff --git a/pkg/machine/machine_test.go b/pkg/machine/machine_test.go index 91c92541d..745c57298 100644 --- a/pkg/machine/machine_test.go +++ b/pkg/machine/machine_test.go @@ -8,12 +8,8 @@ import ( "testing" "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/ledgertesting" "github.com/formancehq/ledger/pkg/machine/script/compiler" "github.com/formancehq/ledger/pkg/machine/vm" - "github.com/formancehq/ledger/pkg/storage" - "github.com/formancehq/stack/libs/go-libs/pgtesting" - "github.com/google/uuid" "github.com/stretchr/testify/require" ) @@ -23,7 +19,7 @@ type testCase struct { vars map[string]json.RawMessage expectErrorCode error expectResult Result - setup func(t *testing.T, store storage.LedgerStore) + store vm.Store metadata core.Metadata } @@ -152,27 +148,36 @@ var testCases = []testCase{ }, { name: "using metadata", - setup: func(t *testing.T, store storage.LedgerStore) { - require.NoError(t, store.UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ - "sales:001": { + store: vm.StaticStore{ + "sales:001": &core.AccountWithVolumes{ + Account: core.Account{ + Address: "sales:001", + Metadata: core.Metadata{ + "seller": json.RawMessage(`{ + "type": "account", + "value": "users:001" + }`), + }, + }, + Volumes: map[string]core.Volumes{ "COIN": { Input: big.NewInt(100), Output: big.NewInt(0), }, }, - })) - require.NoError(t, store.UpdateAccountMetadata(context.Background(), "sales:001", core.Metadata{ - "seller": json.RawMessage(`{ - "type": "account", - "value": "users:001" - }`), - })) - require.NoError(t, store.UpdateAccountMetadata(context.Background(), "users:001", core.Metadata{ - "commission": json.RawMessage(`{ - "type": "portion", - "value": "15.5%" - }`), - })) + }, + "users:001": &core.AccountWithVolumes{ + Account: core.Account{ + Address: "sales:001", + Metadata: core.Metadata{ + "commission": json.RawMessage(`{ + "type": "portion", + "value": "15.5%" + }`), + }, + }, + Volumes: map[string]core.Volumes{}, + }, }, script: ` vars { @@ -286,16 +291,19 @@ var testCases = []testCase{ }, { name: "balance function", - setup: func(t *testing.T, store storage.LedgerStore) { - require.NoError(t, store.EnsureAccountExists(context.Background(), "users:001")) - require.NoError(t, store.UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ - "users:001": map[string]core.Volumes{ + store: vm.StaticStore{ + "users:001": { + Account: core.Account{ + Address: "users:001", + Metadata: core.Metadata{}, + }, + Volumes: map[string]core.Volumes{ "COIN": { Input: big.NewInt(100), Output: big.NewInt(0), }, }, - })) + }, }, script: ` vars { @@ -330,8 +338,14 @@ var testCases = []testCase{ }, { name: "send amount 0", - setup: func(t *testing.T, store storage.LedgerStore) { - require.NoError(t, store.EnsureAccountExists(context.Background(), "alice")) + store: vm.StaticStore{ + "alice": { + Account: core.Account{ + Address: "alice", + Metadata: core.Metadata{}, + }, + Volumes: map[string]core.Volumes{}, + }, }, script: ` send [USD 0] ( @@ -348,8 +362,14 @@ var testCases = []testCase{ }, { name: "send all with balance 0", - setup: func(t *testing.T, store storage.LedgerStore) { - require.NoError(t, store.EnsureAccountExists(context.Background(), "alice")) + store: vm.StaticStore{ + "alice": { + Account: core.Account{ + Address: "alice", + Metadata: core.Metadata{}, + }, + Volumes: map[string]core.Volumes{}, + }, }, script: ` send [USD *] ( @@ -366,8 +386,14 @@ var testCases = []testCase{ }, { name: "send account balance of 0", - setup: func(t *testing.T, store storage.LedgerStore) { - require.NoError(t, store.EnsureAccountExists(context.Background(), "alice")) + store: vm.StaticStore{ + "alice": { + Account: core.Account{ + Address: "alice", + Metadata: core.Metadata{}, + }, + Volumes: map[string]core.Volumes{}, + }, }, script: ` vars { @@ -390,27 +416,12 @@ var testCases = []testCase{ func TestMachine(t *testing.T) { t.Parallel() - require.NoError(t, pgtesting.CreatePostgresServer()) - defer func() { - _ = pgtesting.DestroyPostgresServer() - }() - - storageDriver := ledgertesting.StorageDriver(t) - require.NoError(t, storageDriver.Initialize(context.Background())) - for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - ledger := uuid.NewString() - - store, _, err := storageDriver.GetLedgerStore(context.Background(), ledger, true) - require.NoError(t, err) - - _, err = store.Initialize(context.Background()) - require.NoError(t, err) - if tc.setup != nil { - tc.setup(t, store) + if tc.store == nil { + tc.store = vm.StaticStore{} } program, err := compiler.Compile(tc.script) @@ -419,9 +430,9 @@ func TestMachine(t *testing.T) { m := vm.NewMachine(*program) require.NoError(t, m.SetVarsFromJSON(tc.vars)) - _, _, err = m.ResolveResources(context.Background(), store) + _, _, err = m.ResolveResources(context.Background(), tc.store) require.NoError(t, err) - require.NoError(t, m.ResolveBalances(context.Background(), store)) + require.NoError(t, m.ResolveBalances(context.Background(), tc.store)) result, err := Run(m, core.RunScript{ Script: core.Script{ diff --git a/pkg/storage/sqlstorage/driver_test.go b/pkg/storage/sqlstorage/driver_test.go index 016041b27..35ab945f3 100644 --- a/pkg/storage/sqlstorage/driver_test.go +++ b/pkg/storage/sqlstorage/driver_test.go @@ -2,18 +2,33 @@ package sqlstorage_test import ( "context" + "os" "testing" - "github.com/formancehq/ledger/pkg/ledgertesting" "github.com/formancehq/ledger/pkg/storage" "github.com/formancehq/ledger/pkg/storage/sqlstorage" ledgerstore "github.com/formancehq/ledger/pkg/storage/sqlstorage/ledger" + "github.com/formancehq/ledger/pkg/storage/sqlstorage/sqlstoragetesting" + "github.com/formancehq/stack/libs/go-libs/logging" + "github.com/formancehq/stack/libs/go-libs/pgtesting" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestMain(t *testing.M) { + if err := pgtesting.CreatePostgresServer(); err != nil { + logging.Error(err) + os.Exit(1) + } + code := t.Run() + if err := pgtesting.DestroyPostgresServer(); err != nil { + logging.Error(err) + } + os.Exit(code) +} + func TestNewDriver(t *testing.T) { - d := ledgertesting.StorageDriver(t) + d := sqlstoragetesting.StorageDriver(t) assert.NoError(t, d.Initialize(context.Background())) @@ -36,7 +51,7 @@ func TestNewDriver(t *testing.T) { } func TestConfiguration(t *testing.T) { - d := ledgertesting.StorageDriver(t) + d := sqlstoragetesting.StorageDriver(t) require.NoError(t, d.Initialize(context.Background())) @@ -47,7 +62,7 @@ func TestConfiguration(t *testing.T) { } func TestConfigurationError(t *testing.T) { - d := ledgertesting.StorageDriver(t) + d := sqlstoragetesting.StorageDriver(t) require.NoError(t, d.Initialize(context.Background())) diff --git a/pkg/storage/sqlstorage/ledger/accounts.go b/pkg/storage/sqlstorage/ledger/accounts.go index 0cb5f4599..595e38a61 100644 --- a/pkg/storage/sqlstorage/ledger/accounts.go +++ b/pkg/storage/sqlstorage/ledger/accounts.go @@ -39,6 +39,10 @@ type AccountsPaginationToken struct { BalanceOperatorFilter storage.BalanceOperator `json:"balanceOperator,omitempty"` } +func (t AccountsPaginationToken) Encode() string { + return encodePaginationToken(t) +} + func (s *Store) buildAccountsQuery(ctx context.Context, p storage.AccountsQuery) (*bun.SelectQuery, AccountsPaginationToken) { sb := s.schema.NewSelect(accountsTableName).Model((*Accounts)(nil)) t := AccountsPaginationToken{} diff --git a/pkg/storage/sqlstorage/ledger/accounts_test.go b/pkg/storage/sqlstorage/ledger/accounts_test.go index c1530e296..bf621088c 100644 --- a/pkg/storage/sqlstorage/ledger/accounts_test.go +++ b/pkg/storage/sqlstorage/ledger/accounts_test.go @@ -5,26 +5,13 @@ import ( "testing" "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/ledgertesting" "github.com/formancehq/ledger/pkg/storage" - "github.com/formancehq/ledger/pkg/storage/sqlstorage" "github.com/stretchr/testify/assert" ) func TestAccounts(t *testing.T) { - d := ledgertesting.StorageDriver(t) - - assert.NoError(t, d.Initialize(context.Background())) - - defer func(d *sqlstorage.Driver, ctx context.Context) { - assert.NoError(t, d.Close(ctx)) - }(d, context.Background()) - - store, _, err := d.GetLedgerStore(context.Background(), "foo", true) - assert.NoError(t, err) - - _, err = store.Initialize(context.Background()) - assert.NoError(t, err) + t.Parallel() + store := newLedgerStore(t) t.Run("success balance", func(t *testing.T) { q := storage.AccountsQuery{ @@ -137,7 +124,7 @@ func TestAccounts(t *testing.T) { assert.Equal(t, addr, account.Address, "account address should match") }) - t.Run("success ensure mulitple accounts exist", func(t *testing.T) { + t.Run("success ensure multiple accounts exist", func(t *testing.T) { addrs := []string{"test:account:4", "test:account:5", "test:account:6"} err := store.EnsureAccountsExist(context.Background(), addrs) diff --git a/pkg/storage/sqlstorage/ledger/balances.go b/pkg/storage/sqlstorage/ledger/balances.go index 47c8f3241..1de655d7c 100644 --- a/pkg/storage/sqlstorage/ledger/balances.go +++ b/pkg/storage/sqlstorage/ledger/balances.go @@ -23,6 +23,10 @@ type BalancesPaginationToken struct { AddressRegexpFilter string `json:"address,omitempty"` } +func (t BalancesPaginationToken) Encode() string { + return encodePaginationToken(t) +} + func (s *Store) GetBalancesAggregated(ctx context.Context, q storage.BalancesQuery) (core.AssetsBalances, error) { if !s.isInitialized { return nil, storageerrors.StorageError(storage.ErrStoreNotInitialized) diff --git a/pkg/storage/sqlstorage/ledger/balances_test.go b/pkg/storage/sqlstorage/ledger/balances_test.go index 696e9615a..f5d475c57 100644 --- a/pkg/storage/sqlstorage/ledger/balances_test.go +++ b/pkg/storage/sqlstorage/ledger/balances_test.go @@ -11,7 +11,10 @@ import ( "github.com/stretchr/testify/require" ) -func testGetBalances(t *testing.T, store storage.LedgerStore) { +func TestGetBalances(t *testing.T) { + t.Parallel() + store := newLedgerStore(t) + require.NoError(t, store.UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ "world": { "USD": core.NewEmptyVolumes().WithOutput(big.NewInt(200)), @@ -139,7 +142,10 @@ func testGetBalances(t *testing.T, store storage.LedgerStore) { }) } -func testGetBalancesAggregated(t *testing.T, store storage.LedgerStore) { +func TestGetBalancesAggregated(t *testing.T) { + t.Parallel() + store := newLedgerStore(t) + require.NoError(t, store.UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ "world": { "USD": core.NewEmptyVolumes().WithOutput(big.NewInt(200)), diff --git a/pkg/storage/sqlstorage/ledger/logs.go b/pkg/storage/sqlstorage/ledger/logs.go index fbad2e116..cf4df8cb0 100644 --- a/pkg/storage/sqlstorage/ledger/logs.go +++ b/pkg/storage/sqlstorage/ledger/logs.go @@ -46,6 +46,10 @@ type LogsPaginationToken struct { EndTime core.Time `json:"endTime,omitempty"` } +func (t LogsPaginationToken) Encode() string { + return encodePaginationToken(t) +} + type RawMessage json.RawMessage func (j RawMessage) Value() (driver.Value, error) { diff --git a/pkg/storage/sqlstorage/ledger/main_test.go b/pkg/storage/sqlstorage/ledger/main_test.go index 9834a8944..5daeadf65 100644 --- a/pkg/storage/sqlstorage/ledger/main_test.go +++ b/pkg/storage/sqlstorage/ledger/main_test.go @@ -1,21 +1,50 @@ package ledger_test import ( + "context" "os" "testing" + "github.com/formancehq/ledger/pkg/storage" + "github.com/formancehq/ledger/pkg/storage/sqlstorage" + "github.com/formancehq/ledger/pkg/storage/sqlstorage/schema" "github.com/formancehq/stack/libs/go-libs/logging" "github.com/formancehq/stack/libs/go-libs/pgtesting" + "github.com/google/uuid" + "github.com/stretchr/testify/require" ) -func TestMain(t *testing.M) { +func TestMain(m *testing.M) { if err := pgtesting.CreatePostgresServer(); err != nil { logging.Error(err) os.Exit(1) } - code := t.Run() + + code := m.Run() if err := pgtesting.DestroyPostgresServer(); err != nil { logging.Error(err) } os.Exit(code) } + +func newLedgerStore(t *testing.T) storage.LedgerStore { + t.Helper() + + pgServer := pgtesting.NewPostgresDatabase(t) + db, err := sqlstorage.OpenSQLDB(pgServer.ConnString()) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + + driver := sqlstorage.NewDriver("postgres", schema.NewPostgresDB(db)) + require.NoError(t, driver.Initialize(context.Background())) + + ledgerStore, _, err := driver.GetLedgerStore(context.Background(), uuid.NewString(), true) + require.NoError(t, err) + + _, err = ledgerStore.Initialize(context.Background()) + require.NoError(t, err) + + return ledgerStore +} diff --git a/pkg/storage/sqlstorage/ledger/migrates/13-clean-logs/any_test.go b/pkg/storage/sqlstorage/ledger/migrates/13-clean-logs/any_test.go index a19a5b95d..759662031 100644 --- a/pkg/storage/sqlstorage/ledger/migrates/13-clean-logs/any_test.go +++ b/pkg/storage/sqlstorage/ledger/migrates/13-clean-logs/any_test.go @@ -6,9 +6,9 @@ import ( "testing" "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/ledgertesting" ledgerstore "github.com/formancehq/ledger/pkg/storage/sqlstorage/ledger" "github.com/formancehq/ledger/pkg/storage/sqlstorage/migrations" + "github.com/formancehq/ledger/pkg/storage/sqlstorage/sqlstoragetesting" "github.com/formancehq/stack/libs/go-libs/pgtesting" "github.com/pborman/uuid" "github.com/stretchr/testify/require" @@ -31,7 +31,7 @@ func TestMigrate(t *testing.T) { require.NoError(t, pgtesting.DestroyPostgresServer()) }() - driver := ledgertesting.StorageDriver(t) + driver := sqlstoragetesting.StorageDriver(t) require.NoError(t, driver.Initialize(context.Background())) store, _, err := driver.GetLedgerStore(context.Background(), uuid.New(), true) diff --git a/pkg/storage/sqlstorage/ledger/migrates/17-optimized-segments/any_test.go b/pkg/storage/sqlstorage/ledger/migrates/17-optimized-segments/any_test.go index 9d4d7e6df..bdffaba2a 100644 --- a/pkg/storage/sqlstorage/ledger/migrates/17-optimized-segments/any_test.go +++ b/pkg/storage/sqlstorage/ledger/migrates/17-optimized-segments/any_test.go @@ -5,9 +5,9 @@ import ( "testing" "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/ledgertesting" ledgerstore "github.com/formancehq/ledger/pkg/storage/sqlstorage/ledger" "github.com/formancehq/ledger/pkg/storage/sqlstorage/migrations" + "github.com/formancehq/ledger/pkg/storage/sqlstorage/sqlstoragetesting" "github.com/formancehq/stack/libs/go-libs/pgtesting" "github.com/pborman/uuid" "github.com/stretchr/testify/require" @@ -19,7 +19,7 @@ func TestMigrate17(t *testing.T) { require.NoError(t, pgtesting.DestroyPostgresServer()) }() - driver := ledgertesting.StorageDriver(t) + driver := sqlstoragetesting.StorageDriver(t) require.NoError(t, driver.Initialize(context.Background())) store, _, err := driver.GetLedgerStore(context.Background(), uuid.New(), true) diff --git a/pkg/storage/sqlstorage/ledger/migrates/9-add-pre-post-volumes/any_test.go b/pkg/storage/sqlstorage/ledger/migrates/9-add-pre-post-volumes/any_test.go index d541fc52b..281668523 100644 --- a/pkg/storage/sqlstorage/ledger/migrates/9-add-pre-post-volumes/any_test.go +++ b/pkg/storage/sqlstorage/ledger/migrates/9-add-pre-post-volumes/any_test.go @@ -11,12 +11,12 @@ import ( "testing" "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/ledgertesting" "github.com/formancehq/ledger/pkg/storage" ledgerstore "github.com/formancehq/ledger/pkg/storage/sqlstorage/ledger" add_pre_post_volumes "github.com/formancehq/ledger/pkg/storage/sqlstorage/ledger/migrates/9-add-pre-post-volumes" "github.com/formancehq/ledger/pkg/storage/sqlstorage/migrations" "github.com/formancehq/ledger/pkg/storage/sqlstorage/schema" + "github.com/formancehq/ledger/pkg/storage/sqlstorage/sqlstoragetesting" "github.com/formancehq/stack/libs/go-libs/pgtesting" "github.com/pborman/uuid" "github.com/stretchr/testify/require" @@ -226,7 +226,7 @@ func TestMigrate9(t *testing.T) { require.NoError(t, pgtesting.DestroyPostgresServer()) }() - driver := ledgertesting.StorageDriver(t) + driver := sqlstoragetesting.StorageDriver(t) require.NoError(t, driver.Initialize(context.Background())) store, _, err := driver.GetLedgerStore(context.Background(), uuid.New(), true) @@ -314,9 +314,9 @@ type Transactions struct { var addressQueryRegexp = regexp.MustCompile(`^(\w+|\*|\.\*)(:(\w+|\*|\.\*))*$`) -func buildTransactionsQuery(ctx context.Context, schema schema.Schema, p storage.TransactionsQuery) (*bun.SelectQuery, ledgerstore.TxsPaginationToken) { +func buildTransactionsQuery(ctx context.Context, schema schema.Schema, p storage.TransactionsQuery) (*bun.SelectQuery, ledgerstore.TransactionsPaginationToken) { sb := schema.NewSelect("transactions").Model((*Transactions)(nil)) - t := ledgerstore.TxsPaginationToken{} + t := ledgerstore.TransactionsPaginationToken{} var ( destination = p.Filters.Destination diff --git a/pkg/storage/sqlstorage/ledger/pagination.go b/pkg/storage/sqlstorage/ledger/pagination.go new file mode 100644 index 000000000..a416ef285 --- /dev/null +++ b/pkg/storage/sqlstorage/ledger/pagination.go @@ -0,0 +1,14 @@ +package ledger + +import ( + "encoding/base64" + "encoding/json" +) + +func encodePaginationToken(t any) string { + data, err := json.Marshal(t) + if err != nil { + panic(err) + } + return base64.RawURLEncoding.EncodeToString(data) +} diff --git a/pkg/storage/sqlstorage/ledger/store_test.go b/pkg/storage/sqlstorage/ledger/store_test.go index a6d565645..f934b4cbf 100644 --- a/pkg/storage/sqlstorage/ledger/store_test.go +++ b/pkg/storage/sqlstorage/ledger/store_test.go @@ -3,83 +3,15 @@ package ledger_test import ( "context" "encoding/json" - "fmt" "math/big" "testing" "time" "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/ledgertesting" "github.com/formancehq/ledger/pkg/storage" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/fx" ) -func TestStore(t *testing.T) { - type testingFunction struct { - name string - fn func(t *testing.T, store storage.LedgerStore) - } - - for _, tf := range []testingFunction{ - {name: "UpdateTransactionMetadata", fn: testUpdateTransactionMetadata}, - {name: "UpdateAccountMetadata", fn: testUpdateAccountMetadata}, - {name: "GetLastLog", fn: testGetLastLog}, - {name: "GetLogs", fn: testGetLogs}, - {name: "CountAccounts", fn: testCountAccounts}, - {name: "GetAssetsVolumes", fn: testGetAssetsVolumes}, - {name: "GetAccounts", fn: testGetAccounts}, - {name: "GetAccountNotFound", fn: testGetAccountNotFound}, - {name: "Transactions", fn: testTransactions}, - {name: "GetTransaction", fn: testGetTransaction}, - {name: "GetBalances", fn: testGetBalances}, - {name: "GetBalancesAggregated", fn: testGetBalancesAggregated}, - } { - t.Run(fmt.Sprintf("postgres/%s", tf.name), func(t *testing.T) { - done := make(chan struct{}) - app := fx.New( - ledgertesting.ProvideStorageDriver(t), - fx.NopLogger, - fx.Invoke(func(driver storage.Driver, lc fx.Lifecycle) { - lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { - defer func() { - close(done) - }() - store, _, err := driver.GetLedgerStore(ctx, uuid.NewString(), true) - if err != nil { - return err - } - defer store.Close(ctx) - - if _, err = store.Initialize(context.Background()); err != nil { - return err - } - - tf.fn(t, store) - return nil - }, - }) - }), - ) - go func() { - require.NoError(t, app.Start(context.Background())) - }() - defer func(app *fx.App, ctx context.Context) { - require.NoError(t, app.Stop(ctx)) - }(app, context.Background()) - - select { - case <-time.After(5 * time.Second): - t.Fatal("timeout") - case <-done: - } - }) - } -} - var now = core.Now() var tx1 = core.ExpandedTransaction{ Transaction: core.Transaction{ @@ -219,7 +151,9 @@ var tx3 = core.ExpandedTransaction{ }, } -func testUpdateTransactionMetadata(t *testing.T, store storage.LedgerStore) { +func TestUpdateTransactionMetadata(t *testing.T) { + t.Parallel() + store := newLedgerStore(t) tx := core.ExpandedTransaction{ Transaction: core.Transaction{ ID: 0, @@ -250,7 +184,10 @@ func testUpdateTransactionMetadata(t *testing.T, store storage.LedgerStore) { require.EqualValues(t, "bar", retrievedTransaction.Metadata["foo"]) } -func testUpdateAccountMetadata(t *testing.T, store storage.LedgerStore) { +func TestUpdateAccountMetadata(t *testing.T) { + t.Parallel() + store := newLedgerStore(t) + require.NoError(t, store.EnsureAccountExists(context.Background(), "central_bank")) err := store.UpdateAccountMetadata(context.Background(), "central_bank", core.Metadata{ @@ -263,13 +200,19 @@ func testUpdateAccountMetadata(t *testing.T, store storage.LedgerStore) { require.EqualValues(t, "bar", account.Metadata["foo"]) } -func testGetAccountNotFound(t *testing.T, store storage.LedgerStore) { +func TestGetAccountNotFound(t *testing.T) { + t.Parallel() + store := newLedgerStore(t) + account, err := store.GetAccount(context.Background(), "account_not_existing") require.True(t, storage.IsNotFoundError(err)) require.Nil(t, account) } -func testCountAccounts(t *testing.T, store storage.LedgerStore) { +func TestCountAccounts(t *testing.T) { + t.Parallel() + store := newLedgerStore(t) + require.NoError(t, store.EnsureAccountExists(context.Background(), "world")) require.NoError(t, store.EnsureAccountExists(context.Background(), "central_bank")) @@ -278,7 +221,10 @@ func testCountAccounts(t *testing.T, store storage.LedgerStore) { require.EqualValues(t, 2, countAccounts) // world + central_bank } -func testGetAssetsVolumes(t *testing.T, store storage.LedgerStore) { +func TestGetAssetsVolumes(t *testing.T) { + t.Parallel() + store := newLedgerStore(t) + require.NoError(t, store.UpdateVolumes(context.Background(), core.AccountsAssetsVolumes{ "central_bank": { "USD": { @@ -295,7 +241,10 @@ func testGetAssetsVolumes(t *testing.T, store storage.LedgerStore) { require.EqualValues(t, big.NewInt(0), volumes["USD"].Output) } -func testGetAccounts(t *testing.T, store storage.LedgerStore) { +func TestGetAccounts(t *testing.T) { + t.Parallel() + store := newLedgerStore(t) + require.NoError(t, store.UpdateAccountMetadata(context.Background(), "world", core.Metadata{ "foo": json.RawMessage(`"bar"`), })) @@ -380,167 +329,10 @@ func testGetAccounts(t *testing.T, store storage.LedgerStore) { require.Len(t, accounts.Data, 1) } -func testTransactions(t *testing.T, store storage.LedgerStore) { - require.NoError(t, store.InsertTransactions(context.Background(), tx1, tx2, tx3)) - - t.Run("Count", func(t *testing.T) { - count, err := store.CountTransactions(context.Background(), storage.TransactionsQuery{}) - require.NoError(t, err) - // Should get all the transactions - require.EqualValues(t, 3, count) - - count, err = store.CountTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - Account: "world", - }, - }) - require.NoError(t, err) - // Should get the two first transactions involving the 'world' account. - require.EqualValues(t, 2, count) - - count, err = store.CountTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - Account: "world", - StartTime: now.Add(-2 * time.Hour), - EndTime: now.Add(-1 * time.Hour), - }, - }) - require.NoError(t, err) - // Should get only tx2, as StartTime is inclusive and EndTime exclusive. - require.EqualValues(t, 1, count) - - count, err = store.CountTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - Metadata: map[string]string{ - "priority": "high", - }, - }, - }) - require.NoError(t, err) - require.EqualValues(t, 1, count) - }) - - t.Run("Get", func(t *testing.T) { - cursor, err := store.GetTransactions(context.Background(), storage.TransactionsQuery{ - PageSize: 1, - }) - require.NoError(t, err) - // Should get only the first transaction. - require.Equal(t, 1, cursor.PageSize) - - cursor, err = store.GetTransactions(context.Background(), storage.TransactionsQuery{ - AfterTxID: cursor.Data[0].ID, - PageSize: 1, - }) - require.NoError(t, err) - // Should get only the second transaction. - require.Equal(t, 1, cursor.PageSize) - - cursor, err = store.GetTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - Account: "world", - Reference: "tx1", - }, - PageSize: 1, - }) - require.NoError(t, err) - require.Equal(t, 1, cursor.PageSize) - // Should get only the first transaction. - require.Len(t, cursor.Data, 1) - - cursor, err = store.GetTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - Account: "users:.*", - }, - PageSize: 10, - }) - require.NoError(t, err) - require.Equal(t, 10, cursor.PageSize) - require.Len(t, cursor.Data, 1) - - cursor, err = store.GetTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - Source: "central_bank", - }, - PageSize: 10, - }) - require.NoError(t, err) - require.Equal(t, 10, cursor.PageSize) - // Should get only the third transaction. - require.Len(t, cursor.Data, 1) - - cursor, err = store.GetTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - Destination: "users:1", - }, - PageSize: 10, - }) - require.NoError(t, err) - require.Equal(t, 10, cursor.PageSize) - // Should get only the third transaction. - require.Len(t, cursor.Data, 1) - - cursor, err = store.GetTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - Destination: "users:.*", // Use regex - }, - PageSize: 10, - }) - assert.NoError(t, err) - assert.Equal(t, 10, cursor.PageSize) - // Should get only the third transaction. - assert.Len(t, cursor.Data, 1) - - cursor, err = store.GetTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - Destination: ".*:1", // Use regex - }, - PageSize: 10, - }) - assert.NoError(t, err) - assert.Equal(t, 10, cursor.PageSize) - // Should get only the third transaction. - assert.Len(t, cursor.Data, 1) - - cursor, err = store.GetTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - Source: ".*bank", // Use regex - }, - PageSize: 10, - }) - assert.NoError(t, err) - assert.Equal(t, 10, cursor.PageSize) - // Should get only the third transaction. - assert.Len(t, cursor.Data, 1) - - cursor, err = store.GetTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - StartTime: now.Add(-2 * time.Hour), - EndTime: now.Add(-1 * time.Hour), - }, - PageSize: 10, - }) - require.NoError(t, err) - require.Equal(t, 10, cursor.PageSize) - // Should get only tx2, as StartTime is inclusive and EndTime exclusive. - require.Len(t, cursor.Data, 1) - - cursor, err = store.GetTransactions(context.Background(), storage.TransactionsQuery{ - Filters: storage.TransactionsQueryFilters{ - Metadata: map[string]string{ - "priority": "high", - }, - }, - PageSize: 10, - }) - require.NoError(t, err) - require.Equal(t, 10, cursor.PageSize) - // Should get only the third transaction. - require.Len(t, cursor.Data, 1) - }) -} +func TestGetTransaction(t *testing.T) { + t.Parallel() + store := newLedgerStore(t) -func testGetTransaction(t *testing.T, store storage.LedgerStore) { require.NoError(t, store.InsertTransactions(context.Background(), tx1, tx2)) tx, err := store.GetTransaction(context.Background(), tx1.ID) @@ -551,27 +343,18 @@ func testGetTransaction(t *testing.T, store storage.LedgerStore) { } func TestInitializeStore(t *testing.T) { - driver := ledgertesting.StorageDriver(t) - defer func(driver storage.Driver, ctx context.Context) { - require.NoError(t, driver.Close(ctx)) - }(driver, context.Background()) - - err := driver.Initialize(context.Background()) - require.NoError(t, err) - - store, _, err := driver.GetLedgerStore(context.Background(), uuid.NewString(), true) - require.NoError(t, err) + t.Parallel() + store := newLedgerStore(t) modified, err := store.Initialize(context.Background()) require.NoError(t, err) - require.True(t, modified) - - modified, err = store.Initialize(context.Background()) - require.NoError(t, err) require.False(t, modified) } -func testGetLastLog(t *testing.T, store storage.LedgerStore) { +func TestGetLastLog(t *testing.T) { + t.Parallel() + store := newLedgerStore(t) + lastLog, err := store.GetLastLog(context.Background()) require.True(t, storage.IsNotFoundError(err)) require.Nil(t, lastLog) @@ -588,7 +371,10 @@ func testGetLastLog(t *testing.T, store storage.LedgerStore) { require.Equal(t, tx1.Timestamp, lastLog.Data.(core.NewTransactionLogPayload).Transaction.Timestamp) } -func testGetLogs(t *testing.T, store storage.LedgerStore) { +func TestGetLogs(t *testing.T) { + t.Parallel() + store := newLedgerStore(t) + for _, tx := range []core.ExpandedTransaction{tx1, tx2, tx3} { logTx := core.NewTransactionLog(tx.Transaction, nil) require.NoError(t, store.AppendLog(context.Background(), &logTx)) diff --git a/pkg/storage/sqlstorage/ledger/transactions.go b/pkg/storage/sqlstorage/ledger/transactions.go index 2b0455a0b..b7626a80d 100644 --- a/pkg/storage/sqlstorage/ledger/transactions.go +++ b/pkg/storage/sqlstorage/ledger/transactions.go @@ -49,7 +49,7 @@ type Postings struct { Destination json.RawMessage `bun:"destination,type:jsonb"` } -type TxsPaginationToken struct { +type TransactionsPaginationToken struct { AfterTxID uint64 `json:"after"` ReferenceFilter string `json:"reference,omitempty"` AccountFilter string `json:"account,omitempty"` @@ -61,9 +61,13 @@ type TxsPaginationToken struct { PageSize uint `json:"pageSize,omitempty"` } -func (s *Store) buildTransactionsQuery(ctx context.Context, p storage.TransactionsQuery) (*bun.SelectQuery, TxsPaginationToken) { +func (t TransactionsPaginationToken) Encode() string { + return encodePaginationToken(t) +} + +func (s *Store) buildTransactionsQuery(ctx context.Context, p storage.TransactionsQuery) (*bun.SelectQuery, TransactionsPaginationToken) { sb := s.schema.NewSelect(TransactionsTableName).Model((*Transactions)(nil)) - t := TxsPaginationToken{} + t := TransactionsPaginationToken{} var ( destination = p.Filters.Destination diff --git a/pkg/storage/sqlstorage/ledger/transactions_test.go b/pkg/storage/sqlstorage/ledger/transactions_test.go index 9e0a6a2de..67958dc2e 100644 --- a/pkg/storage/sqlstorage/ledger/transactions_test.go +++ b/pkg/storage/sqlstorage/ledger/transactions_test.go @@ -7,26 +7,12 @@ import ( "time" "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/ledgertesting" "github.com/formancehq/ledger/pkg/storage" - "github.com/formancehq/ledger/pkg/storage/sqlstorage" "github.com/stretchr/testify/assert" ) func TestTransactions(t *testing.T) { - d := ledgertesting.StorageDriver(t) - - assert.NoError(t, d.Initialize(context.Background())) - - defer func(d *sqlstorage.Driver, ctx context.Context) { - assert.NoError(t, d.Close(ctx)) - }(d, context.Background()) - - store, _, err := d.GetLedgerStore(context.Background(), "foo", true) - assert.NoError(t, err) - - _, err = store.Initialize(context.Background()) - assert.NoError(t, err) + store := newLedgerStore(t) t.Run("success inserting transaction", func(t *testing.T) { tx1 := core.ExpandedTransaction{ diff --git a/pkg/storage/sqlstorage/ledger/volumes_test.go b/pkg/storage/sqlstorage/ledger/volumes_test.go index c90c02ade..bb95fd806 100644 --- a/pkg/storage/sqlstorage/ledger/volumes_test.go +++ b/pkg/storage/sqlstorage/ledger/volumes_test.go @@ -6,13 +6,13 @@ import ( "testing" "github.com/formancehq/ledger/pkg/core" - "github.com/formancehq/ledger/pkg/ledgertesting" "github.com/formancehq/ledger/pkg/storage/sqlstorage" + "github.com/formancehq/ledger/pkg/storage/sqlstorage/sqlstoragetesting" "github.com/stretchr/testify/assert" ) func TestVolumes(t *testing.T) { - d := ledgertesting.StorageDriver(t) + d := sqlstoragetesting.StorageDriver(t) assert.NoError(t, d.Initialize(context.Background())) diff --git a/pkg/storage/sqlstorage/main_test.go b/pkg/storage/sqlstorage/main_test.go deleted file mode 100644 index 884ef4abf..000000000 --- a/pkg/storage/sqlstorage/main_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package sqlstorage_test - -import ( - "os" - "testing" - - "github.com/formancehq/stack/libs/go-libs/logging" - "github.com/formancehq/stack/libs/go-libs/pgtesting" -) - -func TestMain(t *testing.M) { - if err := pgtesting.CreatePostgresServer(); err != nil { - logging.Error(err) - os.Exit(1) - } - code := t.Run() - if err := pgtesting.DestroyPostgresServer(); err != nil { - logging.Error(err) - } - os.Exit(code) -} diff --git a/pkg/storage/sqlstorage/sqlstoragetesting/storage.go b/pkg/storage/sqlstorage/sqlstoragetesting/storage.go new file mode 100644 index 000000000..b89de256e --- /dev/null +++ b/pkg/storage/sqlstorage/sqlstoragetesting/storage.go @@ -0,0 +1,21 @@ +package sqlstoragetesting + +import ( + "github.com/formancehq/ledger/pkg/storage/sqlstorage" + "github.com/formancehq/ledger/pkg/storage/sqlstorage/schema" + "github.com/formancehq/stack/libs/go-libs/pgtesting" + "github.com/stretchr/testify/require" +) + +func StorageDriver(t pgtesting.TestingT) *sqlstorage.Driver { + pgServer := pgtesting.NewPostgresDatabase(t) + + db, err := sqlstorage.OpenSQLDB(pgServer.ConnString()) + require.NoError(t, err) + + t.Cleanup(func() { + db.Close() + }) + + return sqlstorage.NewDriver("postgres", schema.NewPostgresDB(db)) +} diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index f5b8f1e4f..2580d77b8 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -27,6 +27,9 @@ type TransactionsQueryFilters struct { func NewTransactionsQuery() *TransactionsQuery { return &TransactionsQuery{ PageSize: QueryDefaultPageSize, + Filters: TransactionsQueryFilters{ + Metadata: map[string]string{}, + }, } } @@ -142,6 +145,9 @@ func NewBalanceOperator(s string) (BalanceOperator, bool) { func NewAccountsQuery() *AccountsQuery { return &AccountsQuery{ PageSize: QueryDefaultPageSize, + Filters: AccountsQueryFilters{ + Metadata: map[string]string{}, + }, } }