Skip to content

Commit

Permalink
updated query interceptor hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
ganigeorgiev committed Feb 22, 2023
1 parent 707df5d commit f9efbf6
Show file tree
Hide file tree
Showing 4 changed files with 350 additions and 321 deletions.
118 changes: 97 additions & 21 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ import (
"time"
)

// ExecHookFunc executes before op allowing custom handling like auto fail/retry.
type ExecHookFunc func(q *Query, op func() error) error

// OneHookFunc executes right before the query populate the row result from One() call (aka. op).
type OneHookFunc func(q *Query, a interface{}, op func(b interface{}) error) error

// AllHookFunc executes right before the query populate the row result from All() call (aka. op).
type AllHookFunc func(q *Query, sliceA interface{}, op func(sliceB interface{}) error) error

// Params represents a list of parameter values to be bound to a SQL statement.
// The map keys are the parameter names while the map values are the corresponding parameter values.
type Params map[string]interface{}
Expand Down Expand Up @@ -44,6 +53,11 @@ type Query struct {
stmt *sql.Stmt
ctx context.Context

// hooks
execHook ExecHookFunc
oneHook OneHookFunc
allHook AllHookFunc

// FieldMapper maps struct field names to DB column names.
FieldMapper FieldMapFunc
// LastError contains the last error (if any) of the query.
Expand Down Expand Up @@ -96,6 +110,31 @@ func (q *Query) WithContext(ctx context.Context) *Query {
return q
}

// WithExecHook associates the provided exec hook function with the query.
//
// It is called for every Query resolver (Execute(), One(), All(), Row(), Column()),
// allowing you to implement auto fail/retry or any other additional handling.
func (q *Query) WithExecHook(fn ExecHookFunc) *Query {
q.execHook = fn
return q
}

// WithOneHook associates the provided hook function with the query,
// called on q.One(), allowing you to implement custom struct scan based
// on the One() argument and/or result.
func (q *Query) WithOneHook(fn OneHookFunc) *Query {
q.oneHook = fn
return q
}

// WithOneHook associates the provided hook function with the query,
// called on q.All(), allowing you to implement custom slice scan based
// on the All() argument and/or result.
func (q *Query) WithAllHook(fn AllHookFunc) *Query {
q.allHook = fn
return q
}

// logSQL returns the SQL statement with parameters being replaced with the actual values.
// The result is only for logging purpose and should not be used to execute.
func (q *Query) logSQL() string {
Expand Down Expand Up @@ -160,7 +199,19 @@ func (q *Query) Bind(params Params) *Query {
}

// Execute executes the SQL statement without retrieving data.
func (q *Query) Execute() (result sql.Result, err error) {
func (q *Query) Execute() (sql.Result, error) {
var result sql.Result

execErr := q.execWrap(func() error {
var err error
result, err = q.execute()
return err
})

return result, execErr
}

func (q *Query) execute() (result sql.Result, err error) {
err = q.LastError
q.LastError = nil
if err != nil {
Expand Down Expand Up @@ -206,44 +257,62 @@ func (q *Query) Execute() (result sql.Result, err error) {
// the variable to be populated.
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
func (q *Query) One(a interface{}) error {
rows, err := q.Rows()
if err != nil {
return err
}
return rows.one(a)
return q.execWrap(func() error {
rows, err := q.Rows()
if err != nil {
return err
}

if q.oneHook != nil {
return q.oneHook(q, a, rows.one)
}

return rows.one(a)
})
}

// All executes the SQL statement and populates all the resulting rows into a slice of struct or NullStringMap.
// The slice must be given as a pointer. Each slice element must be either a struct or a NullStringMap.
// Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how each slice element can be.
// If the query returns no row, the slice will be an empty slice (not nil).
func (q *Query) All(slice interface{}) error {
rows, err := q.Rows()
if err != nil {
return err
}
return rows.all(slice)
return q.execWrap(func() error {
rows, err := q.Rows()
if err != nil {
return err
}

if q.allHook != nil {
return q.allHook(q, slice, rows.all)
}

return rows.all(slice)
})
}

// Row executes the SQL statement and populates the first row of the result into a list of variables.
// Note that the number of the variables should match to that of the columns in the query result.
// Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned.
func (q *Query) Row(a ...interface{}) error {
rows, err := q.Rows()
if err != nil {
return err
}
return rows.row(a...)
return q.execWrap(func() error {
rows, err := q.Rows()
if err != nil {
return err
}
return rows.row(a...)
})
}

// Column executes the SQL statement and populates the first column of the result into a slice.
// Note that the parameter must be a pointer to a slice.
func (q *Query) Column(a interface{}) error {
rows, err := q.Rows()
if err != nil {
return err
}
return rows.column(a)
return q.execWrap(func() error {
rows, err := q.Rows()
if err != nil {
return err
}
return rows.column(a)
})
}

// Rows executes the SQL statement and returns a Rows object to allow retrieving data row by row.
Expand Down Expand Up @@ -290,6 +359,13 @@ func (q *Query) Rows() (rows *Rows, err error) {
return
}

func (q *Query) execWrap(op func() error) error {
if q.execHook != nil {
return q.execHook(q, op)
}
return op()
}

// replacePlaceholders converts a list of named parameters into a list of anonymous parameters.
func replacePlaceholders(placeholders []string, params Params) ([]interface{}, error) {
if len(placeholders) == 0 {
Expand Down
199 changes: 199 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package dbx
import (
ss "database/sql"
"encoding/json"
"errors"
"testing"
"time"

Expand Down Expand Up @@ -399,3 +400,201 @@ func TestIssue13(t *testing.T) {
assert.NotZero(t, user2.ID)
}
}

func TestQueryWithExecHook(t *testing.T) {
db := getPreparedDB()
defer db.Close()

// error return
{
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
return errors.New("test")
}).
Row()

assert.Error(t, err)
}

// Row()
{
calls := 0
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
calls++
return nil
}).
Row()
assert.Nil(t, err)
assert.Equal(t, 1, calls, "Row()")
}

// One()
{
calls := 0
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
calls++
return nil
}).
One(nil)
assert.Nil(t, err)
assert.Equal(t, 1, calls, "One()")
}

// All()
{
calls := 0
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
calls++
return nil
}).
All(nil)
assert.Nil(t, err)
assert.Equal(t, 1, calls, "All()")
}

// Column()
{
calls := 0
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
calls++
return nil
}).
Column(nil)
assert.Nil(t, err)
assert.Equal(t, 1, calls, "Column()")
}

// Execute()
{
calls := 0
_, err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
calls++
return nil
}).
Execute()
assert.Nil(t, err)
assert.Equal(t, 1, calls, "Execute()")
}

// op call
{
calls := 0
var id int
err := db.NewQuery("select id from user where id = 2").
WithExecHook(func(q *Query, op func() error) error {
calls++
return op()
}).
Row(&id)
assert.Nil(t, err)
assert.Equal(t, 1, calls, "op hook calls")
assert.Equal(t, 2, id, "id mismatch")
}
}

func TestQueryWithOneHook(t *testing.T) {
db := getPreparedDB()
defer db.Close()

// error return
{
err := db.NewQuery("select * from user").
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
return errors.New("test")
}).
One(nil)

assert.Error(t, err)
}

// hooks call order
{
hookCalls := []string{}
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
hookCalls = append(hookCalls, "exec")
return op()
}).
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
hookCalls = append(hookCalls, "one")
return nil
}).
One(nil)

assert.Nil(t, err)
assert.Equal(t, hookCalls, []string{"exec", "one"})
}

// op call
{
calls := 0
other := User{}
err := db.NewQuery("select id from user where id = 2").
WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
calls++
return op(&other)
}).
One(nil)

assert.Nil(t, err)
assert.Equal(t, 1, calls, "hook calls")
assert.Equal(t, int64(2), other.ID, "replaced scan struct")
}
}

func TestQueryWithAllHook(t *testing.T) {
db := getPreparedDB()
defer db.Close()

// error return
{
err := db.NewQuery("select * from user").
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
return errors.New("test")
}).
All(nil)

assert.Error(t, err)
}

// hooks call order
{
hookCalls := []string{}
err := db.NewQuery("select * from user").
WithExecHook(func(q *Query, op func() error) error {
hookCalls = append(hookCalls, "exec")
return op()
}).
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
hookCalls = append(hookCalls, "all")
return nil
}).
All(nil)

assert.Nil(t, err)
assert.Equal(t, hookCalls, []string{"exec", "all"})
}

// op call
{
calls := 0
other := []User{}
err := db.NewQuery("select id from user order by id asc").
WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error {
calls++
return op(&other)
}).
All(nil)

assert.Nil(t, err)
assert.Equal(t, 1, calls, "hook calls")
assert.Equal(t, 2, len(other), "users length")
assert.Equal(t, int64(1), other[0].ID, "user 1 id check")
assert.Equal(t, int64(2), other[1].ID, "user 2 id check")
}
}
Loading

0 comments on commit f9efbf6

Please sign in to comment.