Skip to content

Commit

Permalink
support postgres sql quotes,fix #15
Browse files Browse the repository at this point in the history
  • Loading branch information
fifsky committed Oct 16, 2019
1 parent df0bbda commit beb871f
Show file tree
Hide file tree
Showing 19 changed files with 305 additions and 188 deletions.
7 changes: 4 additions & 3 deletions APIDESIGN.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ gosql.Use(db).Model(&Users{}}).Get()

## Transaction context switching
```go
gosql.WithTx(tx *sqlx.Tx)
gosql.WithTx(tx).Table("xxxx").Where("id = ?",1).Get(&user)
gosql.WithTx(tx).Model(&Users{}).Get()
gosql.Tx(func(tx *gosql.DB){
tx.Table("xxxx").Where("id = ?",1).Get(&user)
tx.Model(&Users{}).Get()
})
```
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The package based on [sqlx](https://github.com/jmoiron/sqlx), It's simple and ke
- Remove Model interface DbName() function,use the Use() function
- Uniform API design specification, see [APIDESIGN](APIDESIGN.md)
- Relation add `connection:"db2"` struct tag, Solve the cross-library connection problem caused by deleting DbName()
- Discard the WithTx function

## Usage

Expand All @@ -39,8 +40,7 @@ func main(){

//connection database
gosql.Connect(configs)

gosql.DB().QueryRowx("select * from users where id = 1")
gosql.QueryRowx("select * from users where id = 1")
}

```
Expand Down Expand Up @@ -156,16 +156,16 @@ gosql.Model(&user).Get("status")
The `Tx` function has a callback function, if an error is returned, the transaction rollback

```go
gosql.Tx(func(tx *sqlx.Tx) error {
gosql.Tx(func(tx *gosql.DB) error {
for id := 1; id < 10; id++ {
user := &Users{
Id: id,
Name: "test" + strconv.Itoa(id),
Email: "test" + strconv.Itoa(id) + "@test.com",
}

//v2
gosql.WithTx(tx).Model(user).Create()
//v2 support, do some database operations in the transaction (use 'tx' from this point, not 'gosql')
tx.Model(user).Create()

if id == 8 {
return errors.New("interrupt the transaction")
Expand All @@ -174,7 +174,7 @@ gosql.Tx(func(tx *sqlx.Tx) error {

//query with transaction
var num int
err := gosql.WithTx(tx).QueryRowx("select count(*) from user_id = 1").Scan(&num)
err := tx.QueryRowx("select count(*) from user_id = 1").Scan(&num)

if err != nil {
return err
Expand Down Expand Up @@ -234,8 +234,8 @@ gosql.Table("users").Where("id = ?", 1).Count()
//Change database
gosql.Use("db2").Table("users").Where("id = ?", 1).Count()

//Transaction `tx` is *sqlx.Tx for v2
gosql.WithTx(tx).Table("users").Where("id = ?", 1}).Count()
//Transaction `tx`
tx.Table("users").Where("id = ?", 1}).Count()
```


Expand Down
9 changes: 4 additions & 5 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var dbService = make(map[string]*sqlx.DB, 0)

// DB gets the specified database engine,
// or the default DB if no name is specified.
func DB(name ...string) *sqlx.DB {
func Sqlx(name ...string) *sqlx.DB {
dbName := defaultLink
if name != nil {
dbName = name[0]
Expand Down Expand Up @@ -73,11 +73,10 @@ func Connect(configs map[string]*Config) (err error) {
}

if db, ok := dbService[key]; ok {
dbService[key] = sess
db.Close()
} else {
dbService[key] = sess
_ = db.Close()
}

dbService[key] = sess
}
return
}
4 changes: 2 additions & 2 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ func TestMain(m *testing.M) {
ShowSql: true,
}

Connect(configs)
_ = Connect(configs)

m.Run()
}

func TestConnect(t *testing.T) {
db := DB()
db := Sqlx()

if db.DriverName() != "mysql" {
t.Fatalf("sqlx database connection error")
Expand Down
106 changes: 51 additions & 55 deletions wrapper.go → db.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,35 @@ type ISqlx interface {
Select(dest interface{}, query string, args ...interface{}) error
Exec(query string, args ...interface{}) (sql.Result, error)
Rebind(query string) string
DriverName() string
}

var (
defaultWrapper = Use(defaultLink)
)

type BuilderChainFunc func(b *Builder)
type BuilderChainFunc func(b *ModelStruct)

type Wrapper struct {
database string
type DB struct {
database *sqlx.DB
tx *sqlx.Tx
logging bool
RelationMap map[string]BuilderChainFunc
}

func (w *Wrapper) db() ISqlx {
// return database instance, if it is a transaction, the transaction priority is higher
func (w *DB) db() ISqlx {
if w.tx != nil {
return w.tx.Unsafe()
}

return DB(w.database).Unsafe()
return w.database.Unsafe()
}

func ShowSql() *Wrapper {
// ShowSql single show sql log
func ShowSql() *DB {
w := Use(defaultLink)
w.logging = true
return w
}

func (w *Wrapper) argsIn(query string, args []interface{}) (string, []interface{}, error) {
func (w *DB) argsIn(query string, args []interface{}) (string, []interface{}, error) {
newArgs := make([]interface{}, 0)
newQuery, newArgs, err := sqlx.In(query, args...)

Expand All @@ -61,13 +60,22 @@ func (w *Wrapper) argsIn(query string, args []interface{}) (string, []interface{
return newQuery, newArgs, nil
}

//DriverName wrapper sqlx.DriverName
func (w *DB) DriverName() string {
if w.tx != nil {
return w.tx.DriverName()
}

return w.database.DriverName()
}

//Rebind wrapper sqlx.Rebind
func (w *Wrapper) Rebind(query string) string {
func (w *DB) Rebind(query string) string {
return w.db().Rebind(query)
}

//Exec wrapper sqlx.Exec
func (w *Wrapper) Exec(query string, args ...interface{}) (result sql.Result, err error) {
func (w *DB) Exec(query string, args ...interface{}) (result sql.Result, err error) {
defer func(start time.Time) {
logger.Log(&QueryStatus{
Query: query,
Expand All @@ -83,7 +91,7 @@ func (w *Wrapper) Exec(query string, args ...interface{}) (result sql.Result, er
}

//Queryx wrapper sqlx.Queryx
func (w *Wrapper) Queryx(query string, args ...interface{}) (rows *sqlx.Rows, err error) {
func (w *DB) Queryx(query string, args ...interface{}) (rows *sqlx.Rows, err error) {
defer func(start time.Time) {
logger.Log(&QueryStatus{
Query: query,
Expand All @@ -103,7 +111,7 @@ func (w *Wrapper) Queryx(query string, args ...interface{}) (rows *sqlx.Rows, er
}

//QueryRowx wrapper sqlx.QueryRowx
func (w *Wrapper) QueryRowx(query string, args ...interface{}) (rows *sqlx.Row) {
func (w *DB) QueryRowx(query string, args ...interface{}) (rows *sqlx.Row) {
defer func(start time.Time) {
logger.Log(&QueryStatus{
Query: query,
Expand All @@ -120,7 +128,7 @@ func (w *Wrapper) QueryRowx(query string, args ...interface{}) (rows *sqlx.Row)
}

//Get wrapper sqlx.Get
func (w *Wrapper) Get(dest interface{}, query string, args ...interface{}) (err error) {
func (w *DB) Get(dest interface{}, query string, args ...interface{}) (err error) {
defer func(start time.Time) {
logger.Log(&QueryStatus{
Query: query,
Expand Down Expand Up @@ -170,7 +178,7 @@ func indirectType(v reflect.Type) reflect.Type {
}

//Select wrapper sqlx.Select
func (w *Wrapper) Select(dest interface{}, query string, args ...interface{}) (err error) {
func (w *DB) Select(dest interface{}, query string, args ...interface{}) (err error) {
defer func(start time.Time) {
logger.Log(&QueryStatus{
Query: query,
Expand Down Expand Up @@ -207,13 +215,12 @@ func (w *Wrapper) Select(dest interface{}, query string, args ...interface{}) (e
}

//Txx the transaction with context
func (w *Wrapper) Txx(ctx context.Context, fn func(ctx context.Context, tx *sqlx.Tx) error) (err error) {
db := DB(w.database)
tx, err := db.BeginTxx(ctx, nil)
func (w *DB) Txx(ctx context.Context, fn func(ctx context.Context, tx *DB) error) (err error) {
tx, err := w.database.BeginTxx(ctx, nil)

if err != nil {
return err
}
tx = tx.Unsafe()
defer func() {
if err != nil {
err := tx.Rollback()
Expand All @@ -223,21 +230,19 @@ func (w *Wrapper) Txx(ctx context.Context, fn func(ctx context.Context, tx *sqlx
}
}()

err = fn(ctx, tx)
err = fn(ctx, &DB{tx: tx})
if err == nil {
err = tx.Commit()
}
return
}

//Tx the transaction
func (w *Wrapper) Tx(fn func(tx *sqlx.Tx) error) (err error) {
db := DB(w.database)
tx, err := db.Beginx()
func (w *DB) Tx(fn func(w *DB) error) (err error) {
tx, err := w.database.Beginx()
if err != nil {
return err
}
tx = tx.Unsafe()
defer func() {
if err != nil {
err := tx.Rollback()
Expand All @@ -246,8 +251,7 @@ func (w *Wrapper) Tx(fn func(tx *sqlx.Tx) error) (err error) {
}
}
}()

err = fn(tx)
err = fn(&DB{tx: tx})
if err == nil {
err = tx.Commit()
}
Expand All @@ -257,21 +261,19 @@ func (w *Wrapper) Tx(fn func(tx *sqlx.Tx) error) (err error) {
// Table database handler from to table name
// for example:
// gosql.Use("db2").Table("users")
// gosql.WithTx(tx).Table("users").Get()
func (w *Wrapper) Table(t string) *Mapper {
return &Mapper{wrapper: w, SQLBuilder: SQLBuilder{table: t}}
func (w *DB) Table(t string) *Mapper {
return &Mapper{db: w, SQLBuilder: SQLBuilder{table: t, dialect: newDialect(w.DriverName())}}
}

//Model database handler from to struct
//for example:
// gosql.Use("db2").Model(&users{})
// gosql.WithTx(tx).Model(&users{}).Get()
func (w *Wrapper) Model(m interface{}) *Builder {
return &Builder{model: m, wrapper: w}
func (w *DB) Model(m interface{}) *ModelStruct {
return &ModelStruct{model: m, db: w, SQLBuilder: SQLBuilder{dialect: newDialect(w.DriverName())}}
}

//Import SQL DDL from sql file
func (w *Wrapper) Import(f string) ([]sql.Result, error) {
func (w *DB) Import(f string) ([]sql.Result, error) {
file, err := os.Open(f)
if err != nil {
return nil, err
Expand Down Expand Up @@ -313,7 +315,7 @@ func (w *Wrapper) Import(f string) ([]sql.Result, error) {
}

// Relation association table builder handle
func (w *Wrapper) Relation(name string, fn BuilderChainFunc) *Wrapper {
func (w *DB) Relation(name string, fn BuilderChainFunc) *DB {
if w.RelationMap == nil {
w.RelationMap = make(map[string]BuilderChainFunc)
}
Expand All @@ -322,57 +324,52 @@ func (w *Wrapper) Relation(name string, fn BuilderChainFunc) *Wrapper {
}

//Use is change database
func Use(db string) *Wrapper {
return &Wrapper{database: db}
}

//WithTx use the specified transaction session
func WithTx(tx *sqlx.Tx) *Wrapper {
return &Wrapper{tx: tx}
func Use(db string) *DB {
return &DB{database: Sqlx(db)}
}

//Exec default database
func Exec(query string, args ...interface{}) (sql.Result, error) {
return defaultWrapper.Exec(query, args...)
return Use(defaultLink).Exec(query, args...)
}

//Queryx default database
func Queryx(query string, args ...interface{}) (*sqlx.Rows, error) {
return defaultWrapper.Queryx(query, args...)
return Use(defaultLink).Queryx(query, args...)
}

//QueryRowx default database
func QueryRowx(query string, args ...interface{}) *sqlx.Row {
return defaultWrapper.QueryRowx(query, args...)
return Use(defaultLink).QueryRowx(query, args...)
}

//Txx default database the transaction with context
func Txx(ctx context.Context, fn func(ctx context.Context, tx *sqlx.Tx) error) error {
return defaultWrapper.Txx(ctx, fn)
func Txx(ctx context.Context, fn func(ctx context.Context, tx *DB) error) error {
return Use(defaultLink).Txx(ctx, fn)
}

//Tx default database the transaction
func Tx(fn func(tx *sqlx.Tx) error) error {
return defaultWrapper.Tx(fn)
func Tx(fn func(tx *DB) error) error {
return Use(defaultLink).Tx(fn)
}

//Get default database
func Get(dest interface{}, query string, args ...interface{}) error {
return defaultWrapper.Get(dest, query, args...)
return Use(defaultLink).Get(dest, query, args...)
}

//Select default database
func Select(dest interface{}, query string, args ...interface{}) error {
return defaultWrapper.Select(dest, query, args...)
return Use(defaultLink).Select(dest, query, args...)
}

// Import SQL DDL from io.Reader
func Import(f string) ([]sql.Result, error) {
return defaultWrapper.Import(f)
return Use(defaultLink).Import(f)
}

// Relation association table builder handle
func Relation(name string, fn BuilderChainFunc) *Wrapper {
func Relation(name string, fn BuilderChainFunc) *DB {
w := Use(defaultLink)
w.RelationMap = make(map[string]BuilderChainFunc)
w.RelationMap[name] = fn
Expand All @@ -382,5 +379,4 @@ func Relation(name string, fn BuilderChainFunc) *Wrapper {
// SetDefaultLink set default link name
func SetDefaultLink(db string) {
defaultLink = db
defaultWrapper = Use(defaultLink)
}
Loading

0 comments on commit beb871f

Please sign in to comment.