Skip to content

Commit

Permalink
refactor: build base select and base delete statement dynamically (ad…
Browse files Browse the repository at this point in the history
…dress comment from @arnaubennassar)
  • Loading branch information
Stefan-Ethernal committed Oct 2, 2024
1 parent db40fc9 commit d136818
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 56 deletions.
114 changes: 71 additions & 43 deletions ethtxmanager/sqlstorage/sqlstorage.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"time"

Expand All @@ -18,15 +17,6 @@ import (
)

const (
// baseSelectQuery represents the base select query, that retrieves all the values from the monitored_txs table
baseSelectQuery = `SELECT id, from_address, to_address, nonce, value, tx_data, gas, gas_offset, gas_price,
blob_sidecar, blob_gas, blob_gas_price, gas_tip_cap, status,
block_number, history, created_at, updated_at, estimate_gas
FROM monitored_txs`

// baseDeleteStatement represents the base delete statement that deletes all the records from the monitored_txs table
baseDeleteStatement = "DELETE FROM monitored_txs"

// monitoredTxsTable is table name for persisting MonitoredTx objects
monitoredTxsTable = "monitored_txs"
)
Expand Down Expand Up @@ -77,7 +67,7 @@ func (s *SqlStorage) Add(_ context.Context, mTx types.MonitoredTx) error {

err := meddler.Insert(s.db, monitoredTxsTable, &mTx)
if err != nil {
sqlErr, success := UnwrapSQLiteErr(err)
sqlErr, success := unwrapSQLiteErr(err)
if !success {
return err
}
Expand All @@ -93,9 +83,12 @@ func (s *SqlStorage) Add(_ context.Context, mTx types.MonitoredTx) error {
// Remove deletes a monitored transaction from the database by its ID.
// If the transaction does not exist, it returns an ErrNotFound error.
func (s *SqlStorage) Remove(ctx context.Context, id common.Hash) error {
query := baseDeleteStatement + " WHERE id = $1"
baseDeleteStmt := buildBaseDeleteStatement(monitoredTxsTable)

result, err := s.db.ExecContext(ctx, query, id.Hex())
var queryBuilder strings.Builder
queryBuilder.WriteString(baseDeleteStmt + " WHERE id = $1")

result, err := s.db.ExecContext(ctx, queryBuilder.String(), id.Hex())
if err != nil {
return err
}
Expand All @@ -116,11 +109,17 @@ func (s *SqlStorage) Remove(ctx context.Context, id common.Hash) error {
// Get retrieves a monitored transaction from the database by its ID.
// If the transaction is not found, it returns an ErrNotFound error.
func (s *SqlStorage) Get(_ context.Context, id common.Hash) (types.MonitoredTx, error) {
query := baseSelectQuery + " WHERE id = $1"
var tx *types.MonitoredTx
baseQuery, err := buildBaseSelectQuery(tx, monitoredTxsTable)
if err != nil {
return types.MonitoredTx{}, err
}

query := baseQuery + " WHERE id = $1"

// Execute the query to retrieve the transaction data.
var mTx types.MonitoredTx
err := meddler.QueryRow(s.db, &mTx, query, id.Hex())
err = meddler.QueryRow(s.db, &mTx, query, id.Hex())
if err != nil {
if err.Error() == errNoRowsInResultSet.Error() {
return types.MonitoredTx{}, types.ErrNotFound
Expand All @@ -136,20 +135,25 @@ func (s *SqlStorage) Get(_ context.Context, id common.Hash) (types.MonitoredTx,
// If no statuses are provided, it returns all transactions.
// The transactions are ordered by their creation date (oldest first).
func (s *SqlStorage) GetByStatus(_ context.Context, statuses []types.MonitoredTxStatus) ([]types.MonitoredTx, error) {
query := baseSelectQuery
var tx *types.MonitoredTx
baseQuery, err := buildBaseSelectQuery(tx, monitoredTxsTable)
if err != nil {
return nil, err
}

query := baseQuery
args := make([]interface{}, 0, len(statuses))

if len(statuses) > 0 {
placeholders := make([]string, len(statuses))
// Build the WHERE clause for status filtering
query += " WHERE status IN ("
for i, status := range statuses {
query += fmt.Sprintf("$%d", i+1)
if i != len(statuses)-1 {
query += ", "
}
placeholders[i] = fmt.Sprintf("$%d", i+1)
args = append(args, string(status))
}
query += ")"

// Build the WHERE clause with the joined placeholders
query += " WHERE status IN (" + strings.Join(placeholders, ", ") + ")"
}

// Add ordering by creation date (oldest first)
Expand All @@ -166,7 +170,13 @@ func (s *SqlStorage) GetByStatus(_ context.Context, statuses []types.MonitoredTx

// GetByBlock loads all monitored transactions that have the blockNumber between fromBlock and toBlock.
func (s *SqlStorage) GetByBlock(ctx context.Context, fromBlock, toBlock *uint64) ([]types.MonitoredTx, error) {
query := baseSelectQuery
var tx *types.MonitoredTx
baseQuery, err := buildBaseSelectQuery(tx, monitoredTxsTable)
if err != nil {
return nil, err
}

query := baseQuery
const maxArgs = 2

args := make([]interface{}, 0, maxArgs)
Expand All @@ -188,7 +198,7 @@ func (s *SqlStorage) GetByBlock(ctx context.Context, fromBlock, toBlock *uint64)

// Use meddler.QueryAll to execute the query and scan into the result slice.
var monitoredTxs []*types.MonitoredTx
err := meddler.QueryAll(s.db, &monitoredTxs, query, args...)
err = meddler.QueryAll(s.db, &monitoredTxs, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to query monitored transactions by block: %w", err)
}
Expand All @@ -201,27 +211,27 @@ func (s *SqlStorage) Update(ctx context.Context, mTx types.MonitoredTx) error {
mTx.UpdatedAt = time.Now()

columns, err := meddler.Columns(&mTx, false)
if err != nil {
if err != nil || len(columns) == 0 {
return fmt.Errorf("failed to build the update statement (column names resolution failed): %w", err)
}

// Use strings.Builder instead of fmt.Sprintf for safer query building
var builder strings.Builder
builder.WriteString("UPDATE " + monitoredTxsTable + " SET ")

// Build the SET clause
// Skip the first column (primary key)
for i, column := range columns[1:] {
if i > 0 {
builder.WriteString(", ")
}
builder.WriteString(column + " = $" + strconv.Itoa(i+1))
placeholders, err := meddler.Placeholders(&mTx, false)
if err != nil || len(placeholders) == 0 {
return fmt.Errorf("failed to build the update statement (placeholders resolution failed): %w", err)
}

// Add the WHERE clause for the primary key
builder.WriteString(" WHERE id = $" + strconv.Itoa(len(columns)))
placeholdersWithoutPK := placeholders[1:]

query := builder.String()
// Use strings.Builder for efficient query building
var queryBuilder strings.Builder
queryBuilder.WriteString("UPDATE " + monitoredTxsTable + " SET ")

// Build the SET clause (skip the first column)
setClauses := make([]string, len(columns)-1)
for i, column := range columns[1:] {
setClauses[i] = column + " = " + placeholdersWithoutPK[i]
}
queryBuilder.WriteString(strings.Join(setClauses, ", ") + " WHERE id = " + placeholders[len(placeholders)-1])

args, err := meddler.Values(&mTx, false)
if err != nil {
Expand All @@ -236,7 +246,7 @@ func (s *SqlStorage) Update(ctx context.Context, mTx types.MonitoredTx) error {
args = append(args[1:], mTx.ID.Hex())

// Execute the query with the arguments
result, err := s.db.ExecContext(ctx, query, args...)
result, err := s.db.ExecContext(ctx, queryBuilder.String(), args...)
if err != nil {
return fmt.Errorf("failed to update monitored transaction: %w", err)
}
Expand All @@ -255,15 +265,33 @@ func (s *SqlStorage) Update(ctx context.Context, mTx types.MonitoredTx) error {

// Empty clears all the records from the monitored_txs table.
func (s *SqlStorage) Empty(ctx context.Context) error {
_, err := s.db.ExecContext(ctx, baseDeleteStatement)
_, err := s.db.ExecContext(ctx, buildBaseDeleteStatement(monitoredTxsTable))
if err != nil {
return fmt.Errorf("failed to empty monitored_txs table: %w", err)
}

return nil
}

// UnwrapSQLiteErr attempts to extract a *sqlite.Error from the given error.
// buildBaseSelectQuery creates SELECT query dynamically based on the provided entity and table name
func buildBaseSelectQuery(src interface{}, tableName string) (string, error) {
var queryBuilder strings.Builder
cols, err := meddler.Columns(src, false)
if err != nil {
return "", err
}

queryBuilder.WriteString("SELECT " + strings.Join(cols, ", ") + " FROM " + tableName)

return queryBuilder.String(), nil
}

// buildBaseDeleteStatement creates DELETE statement dynamically based on the provided table name
func buildBaseDeleteStatement(tableName string) string {
return "DELETE FROM " + tableName
}

// unwrapSQLiteErr attempts to extract a *sqlite.Error from the given error.
// It first checks if the error is directly of type *sqlite.Error, and if not,
// it tries to unwrap it from a meddler.DriverErr.
//
Expand All @@ -273,7 +301,7 @@ func (s *SqlStorage) Empty(ctx context.Context) error {
// Returns:
// - *sqlite.Error: The extracted SQLite error, or nil if not found.
// - bool: True if the error was successfully unwrapped as a *sqlite.Error.
func UnwrapSQLiteErr(err error) (*sqlite.Error, bool) {
func unwrapSQLiteErr(err error) (*sqlite.Error, bool) {
sqliteErr := &sqlite.Error{}
if ok := errors.As(err, sqliteErr); ok {
return sqliteErr, true
Expand Down
33 changes: 20 additions & 13 deletions ethtxmanager/sqlstorage/sqlstorage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package sqlstorage

import (
"context"
"fmt"

Check failure on line 5 in ethtxmanager/sqlstorage/sqlstorage_test.go

View workflow job for this annotation

GitHub Actions / lint

"fmt" imported and not used

Check failure on line 5 in ethtxmanager/sqlstorage/sqlstorage_test.go

View workflow job for this annotation

GitHub Actions / test-unit (1.21.x, amd64)

"fmt" imported and not used
"math/big"
"sync"

Check failure on line 7 in ethtxmanager/sqlstorage/sqlstorage_test.go

View workflow job for this annotation

GitHub Actions / lint

"sync" imported and not used (typecheck)

Check failure on line 7 in ethtxmanager/sqlstorage/sqlstorage_test.go

View workflow job for this annotation

GitHub Actions / test-unit (1.21.x, amd64)

"sync" imported and not used
"testing"
"time"

Expand Down Expand Up @@ -107,7 +109,7 @@ func TestSqlStorage_Remove(t *testing.T) {
func TestSqlStorage_Get(t *testing.T) {
ctx := context.Background()

storage, err := NewStorage("sqlite3", ":memory:")
storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:")
require.NoError(t, err)
defer storage.db.Close()

Expand Down Expand Up @@ -151,7 +153,7 @@ func TestSqlStorage_Get(t *testing.T) {
func TestSqlStorage_GetByStatus(t *testing.T) {
ctx := context.Background()

storage, err := NewStorage("sqlite3", ":memory:")
storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:")
require.NoError(t, err)
defer storage.db.Close()

Expand Down Expand Up @@ -203,7 +205,7 @@ func TestSqlStorage_GetByStatus(t *testing.T) {
func TestSqlStorage_GetByBlock(t *testing.T) {
ctx := context.Background()

storage, err := NewStorage("sqlite3", ":memory:")
storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:")
require.NoError(t, err)
defer storage.db.Close()

Expand Down Expand Up @@ -261,7 +263,7 @@ func TestSqlStorage_Update(t *testing.T) {
ctx := context.Background()

// Setup a temporary SQLite database for testing
storage, err := NewStorage("sqlite3", ":memory:")
storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:")
require.NoError(t, err)
defer storage.db.Close()

Expand Down Expand Up @@ -321,14 +323,7 @@ func TestSqlStorage_Update(t *testing.T) {
// Verify that the transaction was updated correctly
updatedTx, err := storage.Get(ctx, test.updateTx.ID)
require.NoError(t, err)
require.Equal(t, test.updateTx.From, updatedTx.From)
require.Equal(t, test.updateTx.To, updatedTx.To)
require.Equal(t, test.updateTx.Value, updatedTx.Value)
require.Equal(t, test.updateTx.Data, updatedTx.Data)
require.Equal(t, test.updateTx.Gas, updatedTx.Gas)
require.Equal(t, test.updateTx.GasPrice, updatedTx.GasPrice)
require.Equal(t, test.updateTx.Status, updatedTx.Status)
require.Equal(t, test.updateTx.BlockNumber, updatedTx.BlockNumber)
compareTxsWithoutDates(t, test.updateTx, updatedTx)
}
})
}
Expand All @@ -338,7 +333,7 @@ func TestSqlStorage_Empty(t *testing.T) {
ctx := context.Background()

// Setup a temporary SQLite database for testing
storage, err := NewStorage("sqlite3", ":memory:")
storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:")
require.NoError(t, err)
defer storage.db.Close()

Expand All @@ -364,6 +359,18 @@ func TestSqlStorage_Empty(t *testing.T) {
_, err = storage.Get(ctx, tx2.ID)
require.ErrorIs(t, err, types.ErrNotFound)
}
func TestSqlStorage_MonitoredTxTableExists(t *testing.T) {
storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:")
require.NoError(t, err)
defer storage.db.Close()

// Check if the monitored_txs table exists
query := `SELECT name FROM sqlite_master WHERE type='table' AND name='monitored_txs';`
var tableName string
err = storage.db.QueryRow(query).Scan(&tableName)
require.NoError(t, err)
require.Equal(t, "monitored_txs", tableName)
}

// Helper function to create a MonitoredTx for testing
func newMonitoredTx(idHex string, fromHex string, toHex string, nonce uint64, status types.MonitoredTxStatus, blockNumber int64) types.MonitoredTx {
Expand Down

0 comments on commit d136818

Please sign in to comment.