diff --git a/go.mod b/go.mod index ed47ce212..f5db502aa 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,14 @@ go 1.16 require ( github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.1 github.com/gin-contrib/cors v1.3.1 github.com/gin-contrib/logger v0.2.0 github.com/gin-gonic/gin v1.7.7 github.com/go-openapi/spec v0.20.4 // indirect github.com/google/go-cmp v0.5.6 github.com/huandu/go-sqlbuilder v1.13.0 + github.com/jackc/pgconn v1.10.1 github.com/jackc/pgx/v4 v4.14.1 github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-sqlite3 v1.14.9 diff --git a/pkg/api/controllers/transaction_controller.go b/pkg/api/controllers/transaction_controller.go index 6b2627982..93399a6ac 100644 --- a/pkg/api/controllers/transaction_controller.go +++ b/pkg/api/controllers/transaction_controller.go @@ -2,6 +2,7 @@ package controllers import ( "errors" + "github.com/numary/ledger/pkg/storage" "net/http" "github.com/gin-gonic/gin" @@ -74,6 +75,14 @@ func (ctl *TransactionController) PostTransaction(c *gin.Context) { ts, err := l.(*ledger.Ledger).Commit(c, []core.Transaction{t}) if err != nil { + switch eerr := err.(type) { + case *storage.Error: + switch eerr.Code { + case storage.ConstraintFailed: + ctl.responseError(c, http.StatusConflict, err) + return + } + } ctl.responseError( c, http.StatusInternalServerError, diff --git a/pkg/storage/sqlstorage/accounts.go b/pkg/storage/sqlstorage/accounts.go index 6f665185d..06c596dac 100644 --- a/pkg/storage/sqlstorage/accounts.go +++ b/pkg/storage/sqlstorage/accounts.go @@ -39,7 +39,7 @@ func (s *Store) FindAccounts(ctx context.Context, q query.Query) (query.Cursor, ) if err != nil { - return c, err + return c, s.error(err) } for rows.Next() { diff --git a/pkg/storage/sqlstorage/aggregations.go b/pkg/storage/sqlstorage/aggregations.go index a288f820f..3faba2ec7 100644 --- a/pkg/storage/sqlstorage/aggregations.go +++ b/pkg/storage/sqlstorage/aggregations.go @@ -17,7 +17,7 @@ func (s *Store) CountTransactions(ctx context.Context) (int64, error) { err := s.db.QueryRowContext(ctx, sqlq, args...).Scan(&count) - return count, err + return count, s.error(err) } func (s *Store) CountAccounts(ctx context.Context) (int64, error) { @@ -34,7 +34,7 @@ func (s *Store) CountAccounts(ctx context.Context) (int64, error) { err := s.db.QueryRowContext(ctx, sqlq, args...).Scan(&count) - return count, err + return count, s.error(err) } func (s *Store) CountMeta(ctx context.Context) (int64, error) { @@ -52,7 +52,7 @@ func (s *Store) CountMeta(ctx context.Context) (int64, error) { q := s.db.QueryRowContext(ctx, sqlq, args...) err := q.Scan(&count) - return count, err + return count, s.error(err) } func (s *Store) AggregateBalances(ctx context.Context, address string) (map[string]int64, error) { @@ -61,7 +61,7 @@ func (s *Store) AggregateBalances(ctx context.Context, address string) (map[stri volumes, err := s.AggregateVolumes(ctx, address) if err != nil { - return balances, err + return balances, s.error(err) } for asset := range volumes { @@ -97,7 +97,7 @@ func (s *Store) AggregateVolumes(ctx context.Context, address string) (map[strin rows, err := s.db.QueryContext(ctx, sqlq, args...) if err != nil { - return volumes, err + return volumes, s.error(err) } for rows.Next() { @@ -110,7 +110,7 @@ func (s *Store) AggregateVolumes(ctx context.Context, address string) (map[strin err := rows.Scan(&row.asset, &row.t, &row.amount) if err != nil { - return volumes, err + return volumes, s.error(err) } if _, ok := volumes[row.asset]; !ok { diff --git a/pkg/storage/sqlstorage/driver.go b/pkg/storage/sqlstorage/driver.go index 7ab87a290..3c2dd6d69 100644 --- a/pkg/storage/sqlstorage/driver.go +++ b/pkg/storage/sqlstorage/driver.go @@ -5,17 +5,9 @@ import ( "database/sql" "errors" "fmt" - "github.com/huandu/go-sqlbuilder" "github.com/numary/ledger/pkg/storage" ) -type Flavor = sqlbuilder.Flavor - -var ( - SQLite = sqlbuilder.SQLite - PostgreSQL = sqlbuilder.PostgreSQL -) - var sqlDrivers = map[Flavor]struct { driverName string }{ diff --git a/pkg/storage/sqlstorage/flavor.go b/pkg/storage/sqlstorage/flavor.go new file mode 100644 index 000000000..980984585 --- /dev/null +++ b/pkg/storage/sqlstorage/flavor.go @@ -0,0 +1,41 @@ +package sqlstorage + +import ( + "github.com/huandu/go-sqlbuilder" + "github.com/jackc/pgconn" + "github.com/mattn/go-sqlite3" + "github.com/numary/ledger/pkg/storage" +) + +type Flavor = sqlbuilder.Flavor + +var ( + SQLite = sqlbuilder.SQLite + PostgreSQL = sqlbuilder.PostgreSQL +) + +func errorFromFlavor(f Flavor, err error) error { + if err == nil { + return nil + } + switch f { + case SQLite: + eerr, ok := err.(sqlite3.Error) + if !ok { + return err + } + if eerr.Code == sqlite3.ErrConstraint { + return storage.NewError(storage.ConstraintFailed, err) + } + case PostgreSQL: + eerr, ok := err.(*pgconn.PgError) + if !ok { + return err + } + switch eerr.Code { + case "23505": + return storage.NewError(storage.ConstraintFailed, err) + } + } + return err +} diff --git a/pkg/storage/sqlstorage/metadata.go b/pkg/storage/sqlstorage/metadata.go index 5371cf968..bc9752a08 100644 --- a/pkg/storage/sqlstorage/metadata.go +++ b/pkg/storage/sqlstorage/metadata.go @@ -11,7 +11,7 @@ import ( func (s *Store) LastMetaID(ctx context.Context) (int64, error) { count, err := s.CountMeta(ctx) if err != nil { - return 0, err + return 0, s.error(err) } return count - 1, nil } @@ -36,7 +36,7 @@ func (s *Store) GetMeta(ctx context.Context, ty string, id string) (core.Metadat rows, err := s.db.QueryContext(ctx, sqlq, args...) if err != nil { - return nil, err + return nil, s.error(err) } meta := core.Metadata{} @@ -51,7 +51,7 @@ func (s *Store) GetMeta(ctx context.Context, ty string, id string) (core.Metadat ) if err != nil { - return nil, err + return nil, s.error(err) } var value json.RawMessage @@ -59,7 +59,7 @@ func (s *Store) GetMeta(ctx context.Context, ty string, id string) (core.Metadat err = json.Unmarshal([]byte(metaValue), &value) if err != nil { - return nil, err + return nil, s.error(err) } meta[metaKey] = value @@ -71,7 +71,7 @@ func (s *Store) GetMeta(ctx context.Context, ty string, id string) (core.Metadat func (s *Store) SaveMeta(ctx context.Context, id int64, timestamp, targetType, targetID, key, value string) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return err + return s.error(err) } ib := sqlbuilder.NewInsertBuilder() @@ -101,7 +101,7 @@ func (s *Store) SaveMeta(ctx context.Context, id int64, timestamp, targetType, t logrus.Debugln("failed to save metadata", err) tx.Rollback() - return err + return s.error(err) } err = tx.Commit() diff --git a/pkg/storage/sqlstorage/store.go b/pkg/storage/sqlstorage/store.go index ea6481deb..4bf1c4585 100644 --- a/pkg/storage/sqlstorage/store.go +++ b/pkg/storage/sqlstorage/store.go @@ -26,13 +26,17 @@ type Store struct { func (s *Store) table(name string) string { switch s.flavor { - case sqlbuilder.PostgreSQL: + case PostgreSQL: return fmt.Sprintf(`"%s"."%s"`, s.ledger, name) default: return name } } +func (s *Store) error(err error) error { + return errorFromFlavor(s.flavor, err) +} + func NewStore(name string, flavor sqlbuilder.Flavor, db *sql.DB, onClose func(ctx context.Context) error) (*Store, error) { return &Store{ ledger: name, @@ -56,7 +60,7 @@ func (s *Store) Initialize(ctx context.Context) error { entries, err := migrations.ReadDir(migrationsDir) if err != nil { - return err + return s.error(err) } for _, m := range entries { @@ -82,7 +86,7 @@ func (s *Store) Initialize(ctx context.Context) error { if err != nil { err = fmt.Errorf("failed to run statement %d: %w", i, err) logrus.Errorln(err) - return err + return s.error(err) } } diff --git a/pkg/storage/sqlstorage/store_test.go b/pkg/storage/sqlstorage/store_test.go index db92ea26f..d176109dd 100644 --- a/pkg/storage/sqlstorage/store_test.go +++ b/pkg/storage/sqlstorage/store_test.go @@ -62,6 +62,10 @@ func TestStore(t *testing.T) { name: "SaveTransactions", fn: testSaveTransaction, }, + { + name: "DuplicatedTransaction", + fn: testDuplicatedTransaction, + }, { name: "SaveMeta", fn: testSaveMeta, @@ -167,6 +171,24 @@ func testSaveTransaction(t *testing.T, store storage.Store) { assert.NoError(t, err) } +func testDuplicatedTransaction(t *testing.T, store storage.Store) { + txs := []core.Transaction{ + { + Postings: []core.Posting{ + {}, + }, + Reference: "foo", + }, + } + err := store.SaveTransactions(context.Background(), txs) + assert.NoError(t, err) + + err = store.SaveTransactions(context.Background(), txs) + assert.Error(t, err) + assert.IsType(t, &storage.Error{}, err) + assert.Equal(t, storage.ConstraintFailed, err.(*storage.Error).Code) +} + func testSaveMeta(t *testing.T, store storage.Store) { err := store.SaveMeta(context.Background(), 1, time.Now().Format(time.RFC3339), "transaction", "1", "firstname", "\"YYY\"") diff --git a/pkg/storage/sqlstorage/transactions.go b/pkg/storage/sqlstorage/transactions.go index 5217759e4..253766e45 100644 --- a/pkg/storage/sqlstorage/transactions.go +++ b/pkg/storage/sqlstorage/transactions.go @@ -68,7 +68,7 @@ func (s *Store) FindTransactions(ctx context.Context, q query.Query) (query.Curs ) if err != nil { - return c, err + return c, s.error(err) } transactions := map[int64]core.Transaction{} @@ -114,7 +114,7 @@ func (s *Store) FindTransactions(ctx context.Context, q query.Query) (query.Curs for _, t := range transactions { meta, err := s.GetMeta(ctx, "transaction", fmt.Sprintf("%d", t.ID)) if err != nil { - return c, err + return c, s.error(err) } t.Metadata = meta @@ -143,7 +143,7 @@ func (s *Store) SaveTransactions(ctx context.Context, ts []core.Transaction) err tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return err + return s.error(err) } for _, t := range ts { @@ -163,7 +163,7 @@ func (s *Store) SaveTransactions(ctx context.Context, ts []core.Transaction) err if err != nil { tx.Rollback() - return err + return s.error(err) } for i, p := range t.Postings { @@ -179,7 +179,7 @@ func (s *Store) SaveTransactions(ctx context.Context, ts []core.Transaction) err if err != nil { tx.Rollback() - return err + return s.error(err) } } @@ -187,7 +187,7 @@ func (s *Store) SaveTransactions(ctx context.Context, ts []core.Transaction) err if err != nil { tx.Rollback() - return err + return s.error(err) } for key, value := range t.Metadata { @@ -217,7 +217,7 @@ func (s *Store) SaveTransactions(ctx context.Context, ts []core.Transaction) err if err != nil { tx.Rollback() - return err + return s.error(err) } nextID++ @@ -254,7 +254,7 @@ func (s *Store) GetTransaction(ctx context.Context, txid string) (tx core.Transa ) if err != nil { - return tx, err + return tx, s.error(err) } for rows.Next() { @@ -290,7 +290,7 @@ func (s *Store) GetTransaction(ctx context.Context, txid string) (tx core.Transa meta, err := s.GetMeta(ctx, "transaction", fmt.Sprintf("%d", tx.ID)) if err != nil { - return tx, err + return tx, s.error(err) } tx.Metadata = meta @@ -305,7 +305,7 @@ func (s *Store) LastTransaction(ctx context.Context) (*core.Transaction, error) c, err := s.FindTransactions(ctx, q) if err != nil { - return nil, err + return nil, s.error(err) } txs := (c.Data).([]core.Transaction) diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index 07efb493a..1c57db6f5 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -2,10 +2,33 @@ package storage import ( "context" + "fmt" "github.com/numary/ledger/pkg/core" "github.com/numary/ledger/pkg/ledger/query" ) +type Code string + +const ( + ConstraintFailed Code = "CONSTRAINT_FAILED" +) + +type Error struct { + Code Code + OriginalError error +} + +func (e Error) Error() string { + return fmt.Sprintf("%s [%s]", e.OriginalError, e.Code) +} + +func NewError(code Code, originalError error) *Error { + return &Error{ + Code: code, + OriginalError: originalError, + } +} + type Store interface { LastTransaction(context.Context) (*core.Transaction, error) LastMetaID(context.Context) (int64, error)