Skip to content

Commit

Permalink
add sqlx prepare and unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
sado0823 committed Jan 19, 2023
1 parent 13a2560 commit 05c206a
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 4 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
)

require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/alicebob/miniredis/v2 v2.23.0
github.com/envoyproxy/protoc-gen-validate v0.1.0
github.com/go-playground/form/v4 v4.2.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/vm3X3Scw=
Expand Down
17 changes: 17 additions & 0 deletions kit/store/sqlx/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type (
Exec(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error)
QueryRow(ctx context.Context, v interface{}, query string, args ...interface{}) error
Query(ctx context.Context, v interface{}, query string, args ...interface{}) error
Prepare(ctx context.Context, query string) (StmtSession, error)
}

Conn interface {
Expand All @@ -87,6 +88,22 @@ func (c *conn) Close() error {
return c.db.Close()
}

func (c *conn) Prepare(ctx context.Context, query string) (stmt StmtSession, err error) {
startCtx, span := startSpan(ctx, "Prepare")
defer func() { endSpan(span, err) }()

err = c.brk.DoWithAcceptable(func() error {
sqlStmt, err := c.db.PrepareContext(startCtx, query)
if err != nil {
return err
}
stmt = &statement{stmt: sqlStmt, query: query}
return nil
}, acceptable)

return stmt, err
}

func (c *conn) Transaction(ctx context.Context, fn func(ctx context.Context, session Session) error) (err error) {
startCtx, span := startSpan(ctx, "Transaction")
defer func() { endSpan(span, err) }()
Expand Down
145 changes: 142 additions & 3 deletions kit/store/sqlx/mysql_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,120 @@
package sqlx

import (
"context"
"fmt"
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
)

func TestQueryWithError(t *testing.T) {
type city struct {
ID int64 `json:"id" db:"id"`
Name string `json:"name" db:"name"`
State int64 `json:"state" db:"state"`
Ctime string `json:"ctime" db:"ctime"`
Mtime string `json:"mtime" db:"mtime"`
}

needErr := fmt.Errorf("rows query with breaker error")

runSqlMockTest(t, func(ctx context.Context, conn Conn, mock sqlmock.Sqlmock) {
mock.ExpectQuery("select (.*) from `test`").
WillReturnError(needErr)

var records []*city
err := conn.Query(ctx, &records, "select * from `test`")
assert.EqualError(t, err, needErr.Error())
})
}

func TestQueryRows(t *testing.T) {
type city struct {
ID int64 `json:"id" db:"id"`
Name string `json:"name" db:"name"`
State int64 `json:"state" db:"state"`
Ctime string `json:"ctime" db:"ctime"`
Mtime string `json:"mtime" db:"mtime"`
}

runSqlMockTest(t, func(ctx context.Context, conn Conn, mock sqlmock.Sqlmock) {
rows := mock.NewRows([]string{"id", "name", "state", "ctime", "mtime"}).
AddRow(2, "bar", 2, "2021-01-02 12:11:10", "2022-01-01 12:11:10").
AddRow(1, "foo", 1, "2021-01-01 12:11:10", "2022-01-01 12:11:10")

mock.ExpectQuery("select (.*) from `test` order by id desc").
WillReturnRows(rows)

var records []*city
err := conn.Query(ctx, &records, "select * from `test` order by id desc")
assert.Nil(t, err)
assert.Equal(t, 2, len(records))
for i, record := range records {
t.Log(i, record)
}
})
}

func TestExecUpdate(t *testing.T) {
runSqlMockTest(t, func(ctx context.Context, conn Conn, mock sqlmock.Sqlmock) {
mock.ExpectExec("update test").
WithArgs("foo", 123).WillReturnResult(sqlmock.NewResult(0, 1))

result, err := conn.Exec(ctx, "update test set name = ? where name = ? and id = ?", "foo", 123)
assert.Nil(t, err)
affected, err := result.RowsAffected()
assert.Nil(t, err)
assert.Equal(t, int64(1), affected)
})
}

func TestExecInsert(t *testing.T) {
runSqlMockTest(t, func(ctx context.Context, conn Conn, mock sqlmock.Sqlmock) {
mock.ExpectExec("insert into `test`").
WithArgs("foo", 1, "bar", 2).WillReturnResult(sqlmock.NewResult(2, 2))

result, err := conn.Exec(ctx, "insert into `test` values (?,?),(?,?)", "foo", 1, "bar", 2)
assert.Nil(t, err)
affected, err := result.RowsAffected()
assert.Nil(t, err)
assert.Equal(t, int64(2), affected)
})
}

func TestQueryInt64(t *testing.T) {
runSqlMockTest(t, func(ctx context.Context, conn Conn, mock sqlmock.Sqlmock) {
rows := mock.NewRows([]string{"count"}).AddRow(2233)
mock.ExpectQuery("select (.*) as count from `test` where id > ?").
WithArgs(100).WillReturnRows(rows)

var val int64
err := conn.QueryRow(ctx, &val, "select count(*) as count from `test` where id > ?", 100)
assert.Nil(t, err)
assert.Equal(t, int64(2233), val)
fmt.Println(val)
})
}

func runSqlMockTest(t *testing.T, fn func(ctx context.Context, conn Conn, mock sqlmock.Sqlmock)) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("get sql mock failed, err:%+v", err)
}
with, err := NewWith(db)
if err != nil {
t.Fatalf("new with db err:%+v", err)
}
defer with.Close()

fn(context.Background(), with, mock)

if err = mock.ExpectationsWereMet(); err != nil {
t.Fatalf("sql mock not all expectations were met, err:%+v", err)
}
}

//type city struct {
// ID int64 `json:"id" db:"id"`
// Name string `json:"name" db:"name"`
Expand All @@ -19,12 +130,40 @@ func Test_Mysql(t *testing.T) {
//}
//
//records := make([]city, 0)
////records := make([]int64, 0)
//err = conn.Query(context.Background(), &records, "select * from test222")
//////records := make([]int64, 0)
////err = conn.Query(context.Background(), &records, "select * from test222")
////if err != nil {
//// t.Fatal(err)
////}
////
////for _, record := range records {
//// t.Log(record)
////}
//
////err = conn.Transaction(context.Background(), func(ctx context.Context, session Session) error {
//// res, err := session.Exec(ctx, "insert into `test222` (name,state) values(?,?)", "tx2", -12)
//// if err != nil {
//// return err
//// }
//// id, err := res.LastInsertId()
//// if err != nil {
//// return err
//// }
//// t.Log("got last insert id: ", id)
//// return fmt.Errorf("should tx rollback")
////})
////if err != nil {
//// t.Fatal(err)
////}
//
//prepare, err := conn.Prepare(context.Background(), "select * from test222 where id = ?")
//if err != nil {
// t.Fatal(err)
//}
//err = prepare.Query(context.Background(), &records, 3)
//if err != nil {
// t.Fatal(err)
//}
//
//for _, record := range records {
// t.Log(record)
//}
Expand Down
64 changes: 64 additions & 0 deletions kit/store/sqlx/stmt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package sqlx

import (
"context"
"database/sql"
)

type (
StmtSession interface {
Exec(ctx context.Context, args ...interface{}) (result sql.Result, err error)
QueryRow(ctx context.Context, v interface{}, args ...interface{}) error
Query(ctx context.Context, v interface{}, args ...interface{}) error
Close() error
}

statement struct {
query string
stmt *sql.Stmt
}
)

func (s *statement) Close() error {
return s.stmt.Close()
}

func (s *statement) Exec(ctx context.Context, args ...interface{}) (result sql.Result, err error) {
startCtx, span := startSpan(ctx, "Prepare Exec")
defer func() { endSpan(span, err) }()

return s.stmt.ExecContext(startCtx, args...)
}

func (s *statement) QueryRow(ctx context.Context, v interface{}, args ...interface{}) (err error) {
startCtx, span := startSpan(ctx, "Prepare QueryRow")
defer func() { endSpan(span, err) }()

return s.doQuery(startCtx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, args...)
}

func (s *statement) Query(ctx context.Context, v interface{}, args ...interface{}) (err error) {
startCtx, span := startSpan(ctx, "Prepare Query")
defer func() { endSpan(span, err) }()

return s.doQuery(startCtx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, args...)
}

func (s *statement) doQuery(ctx context.Context, scanner func(*sql.Rows) error, args ...interface{}) error {

rows, err := s.stmt.QueryContext(ctx, args...)
if err != nil {
return err
}
defer rows.Close()

if err = scanner(rows); err != nil {
return err
}

return rows.Err()
}
19 changes: 18 additions & 1 deletion kit/store/sqlx/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ var begin = func(db *sql.DB) (transactionI, error) {
return &transaction{tx}, nil
}

func (c *transaction) Prepare(ctx context.Context, query string) (stmt StmtSession, err error) {
startCtx, span := startSpan(ctx, "Transaction Prepare")
defer func() { endSpan(span, err) }()

var sqlStmt *sql.Stmt
sqlStmt, err = c.PrepareContext(startCtx, query)
if err != nil {
return nil, err
}

return &statement{stmt: sqlStmt, query: query}, nil
}

func (c *transaction) Exec(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
startCtx, span := startSpan(ctx, "Transaction Exec")
defer func() { endSpan(span, err) }()
Expand Down Expand Up @@ -58,5 +71,9 @@ func (c *transaction) query(ctx context.Context, scanner func(rows *sql.Rows) er
return err
}
defer rows.Close()
return scanner(rows)
if err = scanner(rows); err != nil {
return err
}

return rows.Err()
}

0 comments on commit 05c206a

Please sign in to comment.