diff --git a/ethtxmanager/sqlstorage/sqlstorage.go b/ethtxmanager/sqlstorage/sqlstorage.go index 17db429..310dc56 100644 --- a/ethtxmanager/sqlstorage/sqlstorage.go +++ b/ethtxmanager/sqlstorage/sqlstorage.go @@ -5,7 +5,6 @@ import ( "database/sql" "errors" "fmt" - "strconv" "strings" "time" @@ -18,15 +17,6 @@ import ( ) const ( - // baseSelectQuery represents the base select query, that retrieves all the values from the monitored_txs table - baseSelectQuery = `SELECT id, from_address, to_address, nonce, value, tx_data, gas, gas_offset, gas_price, - blob_sidecar, blob_gas, blob_gas_price, gas_tip_cap, status, - block_number, history, created_at, updated_at, estimate_gas - FROM monitored_txs` - - // baseDeleteStatement represents the base delete statement that deletes all the records from the monitored_txs table - baseDeleteStatement = "DELETE FROM monitored_txs" - // monitoredTxsTable is table name for persisting MonitoredTx objects monitoredTxsTable = "monitored_txs" ) @@ -77,7 +67,7 @@ func (s *SqlStorage) Add(_ context.Context, mTx types.MonitoredTx) error { err := meddler.Insert(s.db, monitoredTxsTable, &mTx) if err != nil { - sqlErr, success := UnwrapSQLiteErr(err) + sqlErr, success := unwrapSQLiteErr(err) if !success { return err } @@ -93,9 +83,12 @@ func (s *SqlStorage) Add(_ context.Context, mTx types.MonitoredTx) error { // Remove deletes a monitored transaction from the database by its ID. // If the transaction does not exist, it returns an ErrNotFound error. func (s *SqlStorage) Remove(ctx context.Context, id common.Hash) error { - query := baseDeleteStatement + " WHERE id = $1" + baseDeleteStmt := buildBaseDeleteStatement(monitoredTxsTable) - result, err := s.db.ExecContext(ctx, query, id.Hex()) + var queryBuilder strings.Builder + queryBuilder.WriteString(baseDeleteStmt + " WHERE id = $1") + + result, err := s.db.ExecContext(ctx, queryBuilder.String(), id.Hex()) if err != nil { return err } @@ -116,11 +109,17 @@ func (s *SqlStorage) Remove(ctx context.Context, id common.Hash) error { // Get retrieves a monitored transaction from the database by its ID. // If the transaction is not found, it returns an ErrNotFound error. func (s *SqlStorage) Get(_ context.Context, id common.Hash) (types.MonitoredTx, error) { - query := baseSelectQuery + " WHERE id = $1" + var tx *types.MonitoredTx + baseQuery, err := buildBaseSelectQuery(tx, monitoredTxsTable) + if err != nil { + return types.MonitoredTx{}, err + } + + query := baseQuery + " WHERE id = $1" // Execute the query to retrieve the transaction data. var mTx types.MonitoredTx - err := meddler.QueryRow(s.db, &mTx, query, id.Hex()) + err = meddler.QueryRow(s.db, &mTx, query, id.Hex()) if err != nil { if err.Error() == errNoRowsInResultSet.Error() { return types.MonitoredTx{}, types.ErrNotFound @@ -136,20 +135,25 @@ func (s *SqlStorage) Get(_ context.Context, id common.Hash) (types.MonitoredTx, // If no statuses are provided, it returns all transactions. // The transactions are ordered by their creation date (oldest first). func (s *SqlStorage) GetByStatus(_ context.Context, statuses []types.MonitoredTxStatus) ([]types.MonitoredTx, error) { - query := baseSelectQuery + var tx *types.MonitoredTx + baseQuery, err := buildBaseSelectQuery(tx, monitoredTxsTable) + if err != nil { + return nil, err + } + + query := baseQuery args := make([]interface{}, 0, len(statuses)) if len(statuses) > 0 { + placeholders := make([]string, len(statuses)) // Build the WHERE clause for status filtering - query += " WHERE status IN (" for i, status := range statuses { - query += fmt.Sprintf("$%d", i+1) - if i != len(statuses)-1 { - query += ", " - } + placeholders[i] = fmt.Sprintf("$%d", i+1) args = append(args, string(status)) } - query += ")" + + // Build the WHERE clause with the joined placeholders + query += " WHERE status IN (" + strings.Join(placeholders, ", ") + ")" } // Add ordering by creation date (oldest first) @@ -166,7 +170,13 @@ func (s *SqlStorage) GetByStatus(_ context.Context, statuses []types.MonitoredTx // GetByBlock loads all monitored transactions that have the blockNumber between fromBlock and toBlock. func (s *SqlStorage) GetByBlock(ctx context.Context, fromBlock, toBlock *uint64) ([]types.MonitoredTx, error) { - query := baseSelectQuery + var tx *types.MonitoredTx + baseQuery, err := buildBaseSelectQuery(tx, monitoredTxsTable) + if err != nil { + return nil, err + } + + query := baseQuery const maxArgs = 2 args := make([]interface{}, 0, maxArgs) @@ -188,7 +198,7 @@ func (s *SqlStorage) GetByBlock(ctx context.Context, fromBlock, toBlock *uint64) // Use meddler.QueryAll to execute the query and scan into the result slice. var monitoredTxs []*types.MonitoredTx - err := meddler.QueryAll(s.db, &monitoredTxs, query, args...) + err = meddler.QueryAll(s.db, &monitoredTxs, query, args...) if err != nil { return nil, fmt.Errorf("failed to query monitored transactions by block: %w", err) } @@ -201,27 +211,27 @@ func (s *SqlStorage) Update(ctx context.Context, mTx types.MonitoredTx) error { mTx.UpdatedAt = time.Now() columns, err := meddler.Columns(&mTx, false) - if err != nil { + if err != nil || len(columns) == 0 { return fmt.Errorf("failed to build the update statement (column names resolution failed): %w", err) } - // Use strings.Builder instead of fmt.Sprintf for safer query building - var builder strings.Builder - builder.WriteString("UPDATE " + monitoredTxsTable + " SET ") - - // Build the SET clause - // Skip the first column (primary key) - for i, column := range columns[1:] { - if i > 0 { - builder.WriteString(", ") - } - builder.WriteString(column + " = $" + strconv.Itoa(i+1)) + placeholders, err := meddler.Placeholders(&mTx, false) + if err != nil || len(placeholders) == 0 { + return fmt.Errorf("failed to build the update statement (placeholders resolution failed): %w", err) } - // Add the WHERE clause for the primary key - builder.WriteString(" WHERE id = $" + strconv.Itoa(len(columns))) + placeholdersWithoutPK := placeholders[1:] - query := builder.String() + // Use strings.Builder for efficient query building + var queryBuilder strings.Builder + queryBuilder.WriteString("UPDATE " + monitoredTxsTable + " SET ") + + // Build the SET clause (skip the first column) + setClauses := make([]string, len(columns)-1) + for i, column := range columns[1:] { + setClauses[i] = column + " = " + placeholdersWithoutPK[i] + } + queryBuilder.WriteString(strings.Join(setClauses, ", ") + " WHERE id = " + placeholders[len(placeholders)-1]) args, err := meddler.Values(&mTx, false) if err != nil { @@ -236,7 +246,7 @@ func (s *SqlStorage) Update(ctx context.Context, mTx types.MonitoredTx) error { args = append(args[1:], mTx.ID.Hex()) // Execute the query with the arguments - result, err := s.db.ExecContext(ctx, query, args...) + result, err := s.db.ExecContext(ctx, queryBuilder.String(), args...) if err != nil { return fmt.Errorf("failed to update monitored transaction: %w", err) } @@ -255,7 +265,7 @@ func (s *SqlStorage) Update(ctx context.Context, mTx types.MonitoredTx) error { // Empty clears all the records from the monitored_txs table. func (s *SqlStorage) Empty(ctx context.Context) error { - _, err := s.db.ExecContext(ctx, baseDeleteStatement) + _, err := s.db.ExecContext(ctx, buildBaseDeleteStatement(monitoredTxsTable)) if err != nil { return fmt.Errorf("failed to empty monitored_txs table: %w", err) } @@ -263,7 +273,25 @@ func (s *SqlStorage) Empty(ctx context.Context) error { return nil } -// UnwrapSQLiteErr attempts to extract a *sqlite.Error from the given error. +// buildBaseSelectQuery creates SELECT query dynamically based on the provided entity and table name +func buildBaseSelectQuery(src interface{}, tableName string) (string, error) { + var queryBuilder strings.Builder + cols, err := meddler.Columns(src, false) + if err != nil { + return "", err + } + + queryBuilder.WriteString("SELECT " + strings.Join(cols, ", ") + " FROM " + tableName) + + return queryBuilder.String(), nil +} + +// buildBaseDeleteStatement creates DELETE statement dynamically based on the provided table name +func buildBaseDeleteStatement(tableName string) string { + return "DELETE FROM " + tableName +} + +// unwrapSQLiteErr attempts to extract a *sqlite.Error from the given error. // It first checks if the error is directly of type *sqlite.Error, and if not, // it tries to unwrap it from a meddler.DriverErr. // @@ -273,7 +301,7 @@ func (s *SqlStorage) Empty(ctx context.Context) error { // Returns: // - *sqlite.Error: The extracted SQLite error, or nil if not found. // - bool: True if the error was successfully unwrapped as a *sqlite.Error. -func UnwrapSQLiteErr(err error) (*sqlite.Error, bool) { +func unwrapSQLiteErr(err error) (*sqlite.Error, bool) { sqliteErr := &sqlite.Error{} if ok := errors.As(err, sqliteErr); ok { return sqliteErr, true diff --git a/ethtxmanager/sqlstorage/sqlstorage_test.go b/ethtxmanager/sqlstorage/sqlstorage_test.go index 2629738..579d935 100644 --- a/ethtxmanager/sqlstorage/sqlstorage_test.go +++ b/ethtxmanager/sqlstorage/sqlstorage_test.go @@ -2,7 +2,9 @@ package sqlstorage import ( "context" + "fmt" "math/big" + "sync" "testing" "time" @@ -107,7 +109,7 @@ func TestSqlStorage_Remove(t *testing.T) { func TestSqlStorage_Get(t *testing.T) { ctx := context.Background() - storage, err := NewStorage("sqlite3", ":memory:") + storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:") require.NoError(t, err) defer storage.db.Close() @@ -151,7 +153,7 @@ func TestSqlStorage_Get(t *testing.T) { func TestSqlStorage_GetByStatus(t *testing.T) { ctx := context.Background() - storage, err := NewStorage("sqlite3", ":memory:") + storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:") require.NoError(t, err) defer storage.db.Close() @@ -203,7 +205,7 @@ func TestSqlStorage_GetByStatus(t *testing.T) { func TestSqlStorage_GetByBlock(t *testing.T) { ctx := context.Background() - storage, err := NewStorage("sqlite3", ":memory:") + storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:") require.NoError(t, err) defer storage.db.Close() @@ -261,7 +263,7 @@ func TestSqlStorage_Update(t *testing.T) { ctx := context.Background() // Setup a temporary SQLite database for testing - storage, err := NewStorage("sqlite3", ":memory:") + storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:") require.NoError(t, err) defer storage.db.Close() @@ -321,14 +323,7 @@ func TestSqlStorage_Update(t *testing.T) { // Verify that the transaction was updated correctly updatedTx, err := storage.Get(ctx, test.updateTx.ID) require.NoError(t, err) - require.Equal(t, test.updateTx.From, updatedTx.From) - require.Equal(t, test.updateTx.To, updatedTx.To) - require.Equal(t, test.updateTx.Value, updatedTx.Value) - require.Equal(t, test.updateTx.Data, updatedTx.Data) - require.Equal(t, test.updateTx.Gas, updatedTx.Gas) - require.Equal(t, test.updateTx.GasPrice, updatedTx.GasPrice) - require.Equal(t, test.updateTx.Status, updatedTx.Status) - require.Equal(t, test.updateTx.BlockNumber, updatedTx.BlockNumber) + compareTxsWithoutDates(t, test.updateTx, updatedTx) } }) } @@ -338,7 +333,7 @@ func TestSqlStorage_Empty(t *testing.T) { ctx := context.Background() // Setup a temporary SQLite database for testing - storage, err := NewStorage("sqlite3", ":memory:") + storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:") require.NoError(t, err) defer storage.db.Close() @@ -364,6 +359,18 @@ func TestSqlStorage_Empty(t *testing.T) { _, err = storage.Get(ctx, tx2.ID) require.ErrorIs(t, err, types.ErrNotFound) } +func TestSqlStorage_MonitoredTxTableExists(t *testing.T) { + storage, err := NewStorage(localCommon.SQLLiteDriverName, ":memory:") + require.NoError(t, err) + defer storage.db.Close() + + // Check if the monitored_txs table exists + query := `SELECT name FROM sqlite_master WHERE type='table' AND name='monitored_txs';` + var tableName string + err = storage.db.QueryRow(query).Scan(&tableName) + require.NoError(t, err) + require.Equal(t, "monitored_txs", tableName) +} // Helper function to create a MonitoredTx for testing func newMonitoredTx(idHex string, fromHex string, toHex string, nonce uint64, status types.MonitoredTxStatus, blockNumber int64) types.MonitoredTx {