diff --git a/APIDESIGN.md b/APIDESIGN.md index ce447a8..236e006 100644 --- a/APIDESIGN.md +++ b/APIDESIGN.md @@ -23,7 +23,8 @@ gosql.Use(db).Model(&Users{}}).Get() ## Transaction context switching ```go -gosql.WithTx(tx *sqlx.Tx) -gosql.WithTx(tx).Table("xxxx").Where("id = ?",1).Get(&user) -gosql.WithTx(tx).Model(&Users{}).Get() +gosql.Tx(func(tx *gosql.DB){ + tx.Table("xxxx").Where("id = ?",1).Get(&user) + tx.Model(&Users{}).Get() +}) ``` diff --git a/README.md b/README.md index 66bb228..2fccf45 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ The package based on [sqlx](https://github.com/jmoiron/sqlx), It's simple and ke - Remove Model interface DbName() function,use the Use() function - Uniform API design specification, see [APIDESIGN](APIDESIGN.md) - Relation add `connection:"db2"` struct tag, Solve the cross-library connection problem caused by deleting DbName() +- Discard the WithTx function ## Usage @@ -39,8 +40,7 @@ func main(){ //connection database gosql.Connect(configs) - - gosql.DB().QueryRowx("select * from users where id = 1") + gosql.QueryRowx("select * from users where id = 1") } ``` @@ -156,7 +156,7 @@ gosql.Model(&user).Get("status") The `Tx` function has a callback function, if an error is returned, the transaction rollback ```go -gosql.Tx(func(tx *sqlx.Tx) error { +gosql.Tx(func(tx *gosql.DB) error { for id := 1; id < 10; id++ { user := &Users{ Id: id, @@ -164,8 +164,8 @@ gosql.Tx(func(tx *sqlx.Tx) error { Email: "test" + strconv.Itoa(id) + "@test.com", } - //v2 - gosql.WithTx(tx).Model(user).Create() + //v2 support, do some database operations in the transaction (use 'tx' from this point, not 'gosql') + tx.Model(user).Create() if id == 8 { return errors.New("interrupt the transaction") @@ -174,7 +174,7 @@ gosql.Tx(func(tx *sqlx.Tx) error { //query with transaction var num int - err := gosql.WithTx(tx).QueryRowx("select count(*) from user_id = 1").Scan(&num) + err := tx.QueryRowx("select count(*) from user_id = 1").Scan(&num) if err != nil { return err @@ -234,8 +234,8 @@ gosql.Table("users").Where("id = ?", 1).Count() //Change database gosql.Use("db2").Table("users").Where("id = ?", 1).Count() -//Transaction `tx` is *sqlx.Tx for v2 -gosql.WithTx(tx).Table("users").Where("id = ?", 1}).Count() +//Transaction `tx` +tx.Table("users").Where("id = ?", 1}).Count() ``` diff --git a/connection.go b/connection.go index 2562a0e..3968b7c 100644 --- a/connection.go +++ b/connection.go @@ -18,7 +18,7 @@ var dbService = make(map[string]*sqlx.DB, 0) // DB gets the specified database engine, // or the default DB if no name is specified. -func DB(name ...string) *sqlx.DB { +func Sqlx(name ...string) *sqlx.DB { dbName := defaultLink if name != nil { dbName = name[0] @@ -73,11 +73,10 @@ func Connect(configs map[string]*Config) (err error) { } if db, ok := dbService[key]; ok { - dbService[key] = sess - db.Close() - } else { - dbService[key] = sess + _ = db.Close() } + + dbService[key] = sess } return } diff --git a/connection_test.go b/connection_test.go index a9ee470..b6adc2a 100644 --- a/connection_test.go +++ b/connection_test.go @@ -36,13 +36,13 @@ func TestMain(m *testing.M) { ShowSql: true, } - Connect(configs) + _ = Connect(configs) m.Run() } func TestConnect(t *testing.T) { - db := DB() + db := Sqlx() if db.DriverName() != "mysql" { t.Fatalf("sqlx database connection error") diff --git a/wrapper.go b/db.go similarity index 72% rename from wrapper.go rename to db.go index 31680fb..093a7c7 100644 --- a/wrapper.go +++ b/db.go @@ -21,36 +21,35 @@ type ISqlx interface { Select(dest interface{}, query string, args ...interface{}) error Exec(query string, args ...interface{}) (sql.Result, error) Rebind(query string) string + DriverName() string } -var ( - defaultWrapper = Use(defaultLink) -) - -type BuilderChainFunc func(b *Builder) +type BuilderChainFunc func(b *ModelStruct) -type Wrapper struct { - database string +type DB struct { + database *sqlx.DB tx *sqlx.Tx logging bool RelationMap map[string]BuilderChainFunc } -func (w *Wrapper) db() ISqlx { +// return database instance, if it is a transaction, the transaction priority is higher +func (w *DB) db() ISqlx { if w.tx != nil { return w.tx.Unsafe() } - return DB(w.database).Unsafe() + return w.database.Unsafe() } -func ShowSql() *Wrapper { +// ShowSql single show sql log +func ShowSql() *DB { w := Use(defaultLink) w.logging = true return w } -func (w *Wrapper) argsIn(query string, args []interface{}) (string, []interface{}, error) { +func (w *DB) argsIn(query string, args []interface{}) (string, []interface{}, error) { newArgs := make([]interface{}, 0) newQuery, newArgs, err := sqlx.In(query, args...) @@ -61,13 +60,22 @@ func (w *Wrapper) argsIn(query string, args []interface{}) (string, []interface{ return newQuery, newArgs, nil } +//DriverName wrapper sqlx.DriverName +func (w *DB) DriverName() string { + if w.tx != nil { + return w.tx.DriverName() + } + + return w.database.DriverName() +} + //Rebind wrapper sqlx.Rebind -func (w *Wrapper) Rebind(query string) string { +func (w *DB) Rebind(query string) string { return w.db().Rebind(query) } //Exec wrapper sqlx.Exec -func (w *Wrapper) Exec(query string, args ...interface{}) (result sql.Result, err error) { +func (w *DB) Exec(query string, args ...interface{}) (result sql.Result, err error) { defer func(start time.Time) { logger.Log(&QueryStatus{ Query: query, @@ -83,7 +91,7 @@ func (w *Wrapper) Exec(query string, args ...interface{}) (result sql.Result, er } //Queryx wrapper sqlx.Queryx -func (w *Wrapper) Queryx(query string, args ...interface{}) (rows *sqlx.Rows, err error) { +func (w *DB) Queryx(query string, args ...interface{}) (rows *sqlx.Rows, err error) { defer func(start time.Time) { logger.Log(&QueryStatus{ Query: query, @@ -103,7 +111,7 @@ func (w *Wrapper) Queryx(query string, args ...interface{}) (rows *sqlx.Rows, er } //QueryRowx wrapper sqlx.QueryRowx -func (w *Wrapper) QueryRowx(query string, args ...interface{}) (rows *sqlx.Row) { +func (w *DB) QueryRowx(query string, args ...interface{}) (rows *sqlx.Row) { defer func(start time.Time) { logger.Log(&QueryStatus{ Query: query, @@ -120,7 +128,7 @@ func (w *Wrapper) QueryRowx(query string, args ...interface{}) (rows *sqlx.Row) } //Get wrapper sqlx.Get -func (w *Wrapper) Get(dest interface{}, query string, args ...interface{}) (err error) { +func (w *DB) Get(dest interface{}, query string, args ...interface{}) (err error) { defer func(start time.Time) { logger.Log(&QueryStatus{ Query: query, @@ -170,7 +178,7 @@ func indirectType(v reflect.Type) reflect.Type { } //Select wrapper sqlx.Select -func (w *Wrapper) Select(dest interface{}, query string, args ...interface{}) (err error) { +func (w *DB) Select(dest interface{}, query string, args ...interface{}) (err error) { defer func(start time.Time) { logger.Log(&QueryStatus{ Query: query, @@ -207,13 +215,12 @@ func (w *Wrapper) Select(dest interface{}, query string, args ...interface{}) (e } //Txx the transaction with context -func (w *Wrapper) Txx(ctx context.Context, fn func(ctx context.Context, tx *sqlx.Tx) error) (err error) { - db := DB(w.database) - tx, err := db.BeginTxx(ctx, nil) +func (w *DB) Txx(ctx context.Context, fn func(ctx context.Context, tx *DB) error) (err error) { + tx, err := w.database.BeginTxx(ctx, nil) + if err != nil { return err } - tx = tx.Unsafe() defer func() { if err != nil { err := tx.Rollback() @@ -223,7 +230,7 @@ func (w *Wrapper) Txx(ctx context.Context, fn func(ctx context.Context, tx *sqlx } }() - err = fn(ctx, tx) + err = fn(ctx, &DB{tx: tx}) if err == nil { err = tx.Commit() } @@ -231,13 +238,11 @@ func (w *Wrapper) Txx(ctx context.Context, fn func(ctx context.Context, tx *sqlx } //Tx the transaction -func (w *Wrapper) Tx(fn func(tx *sqlx.Tx) error) (err error) { - db := DB(w.database) - tx, err := db.Beginx() +func (w *DB) Tx(fn func(w *DB) error) (err error) { + tx, err := w.database.Beginx() if err != nil { return err } - tx = tx.Unsafe() defer func() { if err != nil { err := tx.Rollback() @@ -246,8 +251,7 @@ func (w *Wrapper) Tx(fn func(tx *sqlx.Tx) error) (err error) { } } }() - - err = fn(tx) + err = fn(&DB{tx: tx}) if err == nil { err = tx.Commit() } @@ -257,21 +261,19 @@ func (w *Wrapper) Tx(fn func(tx *sqlx.Tx) error) (err error) { // Table database handler from to table name // for example: // gosql.Use("db2").Table("users") -// gosql.WithTx(tx).Table("users").Get() -func (w *Wrapper) Table(t string) *Mapper { - return &Mapper{wrapper: w, SQLBuilder: SQLBuilder{table: t}} +func (w *DB) Table(t string) *Mapper { + return &Mapper{db: w, SQLBuilder: SQLBuilder{table: t, dialect: newDialect(w.DriverName())}} } //Model database handler from to struct //for example: // gosql.Use("db2").Model(&users{}) -// gosql.WithTx(tx).Model(&users{}).Get() -func (w *Wrapper) Model(m interface{}) *Builder { - return &Builder{model: m, wrapper: w} +func (w *DB) Model(m interface{}) *ModelStruct { + return &ModelStruct{model: m, db: w, SQLBuilder: SQLBuilder{dialect: newDialect(w.DriverName())}} } //Import SQL DDL from sql file -func (w *Wrapper) Import(f string) ([]sql.Result, error) { +func (w *DB) Import(f string) ([]sql.Result, error) { file, err := os.Open(f) if err != nil { return nil, err @@ -313,7 +315,7 @@ func (w *Wrapper) Import(f string) ([]sql.Result, error) { } // Relation association table builder handle -func (w *Wrapper) Relation(name string, fn BuilderChainFunc) *Wrapper { +func (w *DB) Relation(name string, fn BuilderChainFunc) *DB { if w.RelationMap == nil { w.RelationMap = make(map[string]BuilderChainFunc) } @@ -322,57 +324,52 @@ func (w *Wrapper) Relation(name string, fn BuilderChainFunc) *Wrapper { } //Use is change database -func Use(db string) *Wrapper { - return &Wrapper{database: db} -} - -//WithTx use the specified transaction session -func WithTx(tx *sqlx.Tx) *Wrapper { - return &Wrapper{tx: tx} +func Use(db string) *DB { + return &DB{database: Sqlx(db)} } //Exec default database func Exec(query string, args ...interface{}) (sql.Result, error) { - return defaultWrapper.Exec(query, args...) + return Use(defaultLink).Exec(query, args...) } //Queryx default database func Queryx(query string, args ...interface{}) (*sqlx.Rows, error) { - return defaultWrapper.Queryx(query, args...) + return Use(defaultLink).Queryx(query, args...) } //QueryRowx default database func QueryRowx(query string, args ...interface{}) *sqlx.Row { - return defaultWrapper.QueryRowx(query, args...) + return Use(defaultLink).QueryRowx(query, args...) } //Txx default database the transaction with context -func Txx(ctx context.Context, fn func(ctx context.Context, tx *sqlx.Tx) error) error { - return defaultWrapper.Txx(ctx, fn) +func Txx(ctx context.Context, fn func(ctx context.Context, tx *DB) error) error { + return Use(defaultLink).Txx(ctx, fn) } //Tx default database the transaction -func Tx(fn func(tx *sqlx.Tx) error) error { - return defaultWrapper.Tx(fn) +func Tx(fn func(tx *DB) error) error { + return Use(defaultLink).Tx(fn) } //Get default database func Get(dest interface{}, query string, args ...interface{}) error { - return defaultWrapper.Get(dest, query, args...) + return Use(defaultLink).Get(dest, query, args...) } //Select default database func Select(dest interface{}, query string, args ...interface{}) error { - return defaultWrapper.Select(dest, query, args...) + return Use(defaultLink).Select(dest, query, args...) } // Import SQL DDL from io.Reader func Import(f string) ([]sql.Result, error) { - return defaultWrapper.Import(f) + return Use(defaultLink).Import(f) } // Relation association table builder handle -func Relation(name string, fn BuilderChainFunc) *Wrapper { +func Relation(name string, fn BuilderChainFunc) *DB { w := Use(defaultLink) w.RelationMap = make(map[string]BuilderChainFunc) w.RelationMap[name] = fn @@ -382,5 +379,4 @@ func Relation(name string, fn BuilderChainFunc) *Wrapper { // SetDefaultLink set default link name func SetDefaultLink(db string) { defaultLink = db - defaultWrapper = Use(defaultLink) } diff --git a/wrapper_test.go b/db_test.go similarity index 91% rename from wrapper_test.go rename to db_test.go index fd74952..2f09554 100644 --- a/wrapper_test.go +++ b/db_test.go @@ -9,8 +9,6 @@ import ( "testing" "time" - "github.com/jmoiron/sqlx" - "github.com/ilibs/gosql/v2/internal/example/models" ) @@ -262,14 +260,14 @@ func TestTx(t *testing.T) { RunWithSchema(t, func(t *testing.T) { //1 { - Tx(func(tx *sqlx.Tx) error { + Tx(func(tx *DB) error { for id := 1; id < 10; id++ { user := &models.Users{ Id: id, Name: "test" + strconv.Itoa(id), } - WithTx(tx).Model(user).Create() + tx.Model(user).Create() if id == 8 { return errors.New("simulation terminated") @@ -292,14 +290,14 @@ func TestTx(t *testing.T) { //2 { - Tx(func(tx *sqlx.Tx) error { + Tx(func(tx *DB) error { for id := 1; id < 10; id++ { user := &models.Users{ Id: id, Name: "test" + strconv.Itoa(id), } - WithTx(tx).Model(user).Create() + tx.Model(user).Create() } return nil @@ -321,16 +319,16 @@ func TestTx(t *testing.T) { func TestWithTx(t *testing.T) { RunWithSchema(t, func(t *testing.T) { { - Tx(func(tx *sqlx.Tx) error { + Tx(func(tx *DB) error { for id := 1; id < 10; id++ { - _, err := WithTx(tx).Exec("INSERT INTO users(id,name,created_at,updated_at) VALUES(?,?,?,?,?)", id, "test"+strconv.Itoa(id), time.Now(), time.Now()) + _, err := tx.Exec("INSERT INTO users(id,name,created_at,updated_at) VALUES(?,?,?,?,?)", id, "test"+strconv.Itoa(id), time.Now(), time.Now()) if err != nil { return err } } var num int - err := WithTx(tx).QueryRowx("select count(*) from users").Scan(&num) + err := tx.QueryRowx("select count(*) from users").Scan(&num) if err != nil { return err @@ -352,14 +350,14 @@ func TestTxx(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - Txx(ctx, func(ctx context.Context, tx *sqlx.Tx) error { + Txx(ctx, func(ctx context.Context, tx *DB) error { for id := 1; id < 10; id++ { user := &models.Users{ Id: id, Name: "test" + strconv.Itoa(id), } - WithTx(tx).Model(user).Create() + tx.Model(user).Create() if id == 8 { cancel() @@ -386,7 +384,7 @@ func TestWrapper_Relation(t *testing.T) { RunWithSchema(t, func(t *testing.T) { initDatas(t) moment := &MomentList{} - err := Relation("User", func(b *Builder) { + err := Relation("User", func(b *ModelStruct) { b.Where("status = 0") }).Get(moment, "select * from moments") @@ -403,7 +401,7 @@ func TestWrapper_Relation2(t *testing.T) { RunWithSchema(t, func(t *testing.T) { initDatas(t) var moments = make([]*MomentList, 0) - err := Relation("User", func(b *Builder) { + err := Relation("User", func(b *ModelStruct) { b.Where("status = 1") }).Select(&moments, "select * from moments") diff --git a/dialect.go b/dialect.go new file mode 100644 index 0000000..72ea094 --- /dev/null +++ b/dialect.go @@ -0,0 +1,55 @@ +package gosql + +import ( + "fmt" +) + +// Dialect interface contains behaviors that differ across SQL database +type Dialect interface { + // GetName get dialect's name + GetName() string + + // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name + Quote(key string) string +} + +type commonDialect struct { +} + +func (commonDialect) GetName() string { + return "common" +} + +func (commonDialect) Quote(key string) string { + return fmt.Sprintf(`"%s"`, key) +} + +var dialectsMap = map[string]Dialect{} + +// RegisterDialect register new dialect +func RegisterDialect(name string, dialect Dialect) { + dialectsMap[name] = dialect +} + +// GetDialect gets the dialect for the specified dialect name +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} + +func mustGetDialect(name string) Dialect { + if dialect, ok := dialectsMap[name]; ok { + return dialect + } + panic(fmt.Sprintf("`%v` is not officially supported", name)) + return nil +} + +func newDialect(name string) Dialect { + if value, ok := GetDialect(name); ok { + return value + } + + fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) + return &commonDialect{} +} diff --git a/dialect_mysql.go b/dialect_mysql.go new file mode 100644 index 0000000..43da4db --- /dev/null +++ b/dialect_mysql.go @@ -0,0 +1,21 @@ +package gosql + +import ( + "fmt" +) + +type mysqlDialect struct { + commonDialect +} + +func init() { + RegisterDialect("mysql", &mysqlDialect{}) +} + +func (mysqlDialect) GetName() string { + return "mysql" +} + +func (mysqlDialect) Quote(key string) string { + return fmt.Sprintf("`%s`", key) +} diff --git a/dialect_postgres.go b/dialect_postgres.go new file mode 100644 index 0000000..2884fb1 --- /dev/null +++ b/dialect_postgres.go @@ -0,0 +1,13 @@ +package gosql + +type postgresDialect struct { + commonDialect +} + +func init() { + RegisterDialect("postgres", &postgresDialect{}) +} + +func (postgresDialect) GetName() string { + return "postgres" +} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go new file mode 100644 index 0000000..d5d245a --- /dev/null +++ b/dialect_sqlite3.go @@ -0,0 +1,13 @@ +package gosql + +type sqlite3Dialect struct { + commonDialect +} + +func init() { + RegisterDialect("sqlite3", &sqlite3Dialect{}) +} + +func (sqlite3Dialect) GetName() string { + return "sqlite3" +} diff --git a/expr_test.go b/expr_test.go index af63dea..ea5bdeb 100644 --- a/expr_test.go +++ b/expr_test.go @@ -6,7 +6,8 @@ import ( func TestExpr(t *testing.T) { b := &SQLBuilder{ - table: "users", + table: "users", + dialect: mustGetDialect("mysql"), } q := b.updateString(map[string]interface{}{ diff --git a/hook.go b/hook.go index e9c1011..bc36466 100644 --- a/hook.go +++ b/hook.go @@ -6,18 +6,16 @@ import ( "log" "reflect" "strings" - - "github.com/jmoiron/sqlx" ) type Hook struct { - wrapper *Wrapper - Errs []error + db *DB + Errs []error } -func NewHook(wrapper *Wrapper) *Hook { +func NewHook(db *DB) *Hook { return &Hook{ - wrapper: wrapper, + db: db, } } @@ -33,10 +31,10 @@ func (h *Hook) callMethod(methodName string, reflectValue reflect.Value) { method() case func() error: h.Err(method()) - case func(*sqlx.Tx): - method(h.wrapper.tx) - case func(*sqlx.Tx) error: - h.Err(method(h.wrapper.tx)) + case func(db *DB): + method(h.db) + case func(db *DB) error: + h.Err(method(h.db)) default: log.Fatal(fmt.Errorf("unsupported function %v", methodName)) } diff --git a/mapper.go b/mapper.go index 9065997..6c8f7e9 100644 --- a/mapper.go +++ b/mapper.go @@ -1,17 +1,19 @@ package gosql type Mapper struct { - wrapper *Wrapper + db *DB SQLBuilder } -func (m *Mapper) ShowSQL() *Mapper { - m.wrapper.logging = true - return m +// Table select table name +func Table(t string) *Mapper { + db := &DB{database: Sqlx(defaultLink)} + return &Mapper{db: db, SQLBuilder: SQLBuilder{table: t, dialect: newDialect(db.DriverName())}} } -func (m *Mapper) db() ISqlx { - return m.wrapper +func (m *Mapper) ShowSQL() *Mapper { + m.db.logging = true + return m } //Where @@ -22,7 +24,7 @@ func (m *Mapper) Where(str string, args ...interface{}) *Mapper { //Update data from to map[string]interface func (m *Mapper) Update(data map[string]interface{}) (affected int64, err error) { - result, err := m.db().Exec(m.updateString(data), m.args...) + result, err := m.db.Exec(m.updateString(data), m.args...) if err != nil { return 0, err } @@ -32,7 +34,7 @@ func (m *Mapper) Update(data map[string]interface{}) (affected int64, err error) //Create data from to map[string]interface func (m *Mapper) Create(data map[string]interface{}) (lastInsertId int64, err error) { - result, err := m.db().Exec(m.insertString(data), m.args...) + result, err := m.db.Exec(m.insertString(data), m.args...) if err != nil { return 0, err } @@ -42,7 +44,7 @@ func (m *Mapper) Create(data map[string]interface{}) (lastInsertId int64, err er //Delete data from to map[string]interface func (m *Mapper) Delete() (affected int64, err error) { - result, err := m.db().Exec(m.deleteString(), m.args...) + result, err := m.db.Exec(m.deleteString(), m.args...) if err != nil { return 0, err } @@ -52,11 +54,6 @@ func (m *Mapper) Delete() (affected int64, err error) { //Count data from to map[string]interface func (m *Mapper) Count() (num int64, err error) { - err = m.db().Get(&num, m.countString(), m.args...) + err = m.db.Get(&num, m.countString(), m.args...) return num, err } - -// Table select table name -func Table(t string) *Mapper { - return &Mapper{wrapper: &Wrapper{database: defaultLink}, SQLBuilder: SQLBuilder{table: t}} -} diff --git a/builder.go b/model.go similarity index 71% rename from builder.go rename to model.go index 7c480a7..0d2ea8f 100644 --- a/builder.go +++ b/model.go @@ -5,8 +5,6 @@ import ( "log" "reflect" "strconv" - - "github.com/jmoiron/sqlx" ) var ( @@ -34,37 +32,34 @@ type IModel interface { PK() string } -type Builder struct { +type ModelStruct struct { model interface{} modelReflectValue reflect.Value + modelEntity IModel + db *DB SQLBuilder - modelEntity IModel - wrapper *Wrapper } // Model construct SQL from Struct -func Model(model interface{}) *Builder { - return &Builder{ - model: model, - wrapper: &Wrapper{database: defaultLink}, +func Model(model interface{}) *ModelStruct { + return &ModelStruct{ + model: model, + db: &DB{database: Sqlx(defaultLink)}, } } // ShowSQL output single sql -func (b *Builder) ShowSQL() *Builder { - b.wrapper.logging = true +func (b *ModelStruct) ShowSQL() *ModelStruct { + b.db.logging = true return b } -func (b *Builder) db() ISqlx { - return b.wrapper -} - -func (b *Builder) initModel() { +func (b *ModelStruct) initModel() { if m, ok := b.model.(IModel); ok { b.modelEntity = m b.table = m.TableName() b.modelReflectValue = reflect.ValueOf(m) + b.dialect = newDialect(b.db.DriverName()) } else { value := reflect.ValueOf(b.model) if value.Kind() != reflect.Ptr { @@ -105,60 +100,56 @@ func (b *Builder) initModel() { b.modelEntity = m b.table = m.TableName() b.modelReflectValue = reflect.ValueOf(m) + b.dialect = newDialect(b.db.DriverName()) } else { log.Fatalf("model argument must implementation IModel interface or slice []IModel and pointer,but get %#v", b.model) } } } -//WithTx model use tx -func (b *Builder) WithTx(tx *sqlx.Tx) *Builder { - b.wrapper.tx = tx - return b -} - //Hint is set TDDL "/*+TDDL:slave()*/" -func (b *Builder) Hint(hint string) *Builder { +func (b *ModelStruct) Hint(hint string) *ModelStruct { b.hint = hint return b } //ForceIndex -func (b *Builder) ForceIndex(i string) *Builder { +func (b *ModelStruct) ForceIndex(i string) *ModelStruct { b.forceIndex = i return b } //Where for example Where("id = ? and name = ?",1,"test") -func (b *Builder) Where(str string, args ...interface{}) *Builder { +func (b *ModelStruct) Where(str string, args ...interface{}) *ModelStruct { b.SQLBuilder.Where(str, args...) return b } -func (b *Builder) Select(fields string) *Builder { +// Select filter column +func (b *ModelStruct) Select(fields string) *ModelStruct { b.fields = fields return b } //Limit -func (b *Builder) Limit(i int) *Builder { +func (b *ModelStruct) Limit(i int) *ModelStruct { b.limit = strconv.Itoa(i) return b } //Offset -func (b *Builder) Offset(i int) *Builder { +func (b *ModelStruct) Offset(i int) *ModelStruct { b.offset = strconv.Itoa(i) return b } //OrderBy for example "id desc" -func (b *Builder) OrderBy(str string) *Builder { +func (b *ModelStruct) OrderBy(str string) *ModelStruct { b.order = str return b } -func (b *Builder) reflectModel(autoTime []string) map[string]reflect.Value { +func (b *ModelStruct) reflectModel(autoTime []string) map[string]reflect.Value { fields := mapper.FieldMap(b.modelReflectValue) if autoTime != nil { structAutoTime(fields, autoTime) @@ -167,34 +158,34 @@ func (b *Builder) reflectModel(autoTime []string) map[string]reflect.Value { } // Relation association table builder handle -func (b *Builder) Relation(fieldName string, fn BuilderChainFunc) *Builder { - if b.wrapper.RelationMap == nil { - b.wrapper.RelationMap = make(map[string]BuilderChainFunc) +func (b *ModelStruct) Relation(fieldName string, fn BuilderChainFunc) *ModelStruct { + if b.db.RelationMap == nil { + b.db.RelationMap = make(map[string]BuilderChainFunc) } - b.wrapper.RelationMap[fieldName] = fn + b.db.RelationMap[fieldName] = fn return b } //All get data row from to Struct -func (b *Builder) Get(zeroValues ...string) (err error) { +func (b *ModelStruct) Get(zeroValues ...string) (err error) { b.initModel() m := zeroValueFilter(b.reflectModel(nil), zeroValues) //If where is empty, the primary key where condition is generated automatically b.generateWhere(m) - return b.db().Get(b.model, b.queryString(), b.args...) + return b.db.Get(b.model, b.queryString(), b.args...) } //All get data rows from to Struct -func (b *Builder) All() (err error) { +func (b *ModelStruct) All() (err error) { b.initModel() - return b.db().Select(b.model, b.queryString(), b.args...) + return b.db.Select(b.model, b.queryString(), b.args...) } //Create data from to Struct -func (b *Builder) Create() (lastInsertId int64, err error) { +func (b *ModelStruct) Create() (lastInsertId int64, err error) { b.initModel() - hook := NewHook(b.wrapper) + hook := NewHook(b.db) hook.callMethod("BeforeChange", b.modelReflectValue) hook.callMethod("BeforeCreate", b.modelReflectValue) if hook.HasError() > 0 { @@ -204,7 +195,7 @@ func (b *Builder) Create() (lastInsertId int64, err error) { fields := b.reflectModel(AUTO_CREATE_TIME_FIELDS) m := structToMap(fields) - result, err := b.db().Exec(b.insertString(m), b.args...) + result, err := b.db.Exec(b.insertString(m), b.args...) if err != nil { return 0, err } @@ -229,13 +220,13 @@ func (b *Builder) Create() (lastInsertId int64, err error) { return lastId, err } -func (b *Builder) generateWhere(m map[string]interface{}) { +func (b *ModelStruct) generateWhere(m map[string]interface{}) { for k, v := range m { b.Where(fmt.Sprintf("%s=?", k), v) } } -func (b *Builder) generateWhereForPK(m map[string]interface{}) { +func (b *ModelStruct) generateWhereForPK(m map[string]interface{}) { pk := b.model.(IModel).PK() pval, has := m[pk] if b.where == "" && has { @@ -245,9 +236,9 @@ func (b *Builder) generateWhereForPK(m map[string]interface{}) { } //gosql.Model(&User{Id:1,Status:0}).Update("status") -func (b *Builder) Update(zeroValues ...string) (affected int64, err error) { +func (b *ModelStruct) Update(zeroValues ...string) (affected int64, err error) { b.initModel() - hook := NewHook(b.wrapper) + hook := NewHook(b.db) hook.callMethod("BeforeChange", b.modelReflectValue) hook.callMethod("BeforeUpdate", b.modelReflectValue) if hook.HasError() > 0 { @@ -260,7 +251,7 @@ func (b *Builder) Update(zeroValues ...string) (affected int64, err error) { //If where is empty, the primary key where condition is generated automatically b.generateWhereForPK(m) - result, err := b.db().Exec(b.updateString(m), b.args...) + result, err := b.db.Exec(b.updateString(m), b.args...) if err != nil { return 0, err } @@ -276,9 +267,9 @@ func (b *Builder) Update(zeroValues ...string) (affected int64, err error) { } //gosql.Model(&User{Id:1}).Delete() -func (b *Builder) Delete(zeroValues ...string) (affected int64, err error) { +func (b *ModelStruct) Delete(zeroValues ...string) (affected int64, err error) { b.initModel() - hook := NewHook(b.wrapper) + hook := NewHook(b.db) hook.callMethod("BeforeChange", b.modelReflectValue) hook.callMethod("BeforeDelete", b.modelReflectValue) if hook.HasError() > 0 { @@ -289,7 +280,7 @@ func (b *Builder) Delete(zeroValues ...string) (affected int64, err error) { //If where is empty, the primary key where condition is generated automatically b.generateWhere(m) - result, err := b.db().Exec(b.deleteString(), b.args...) + result, err := b.db.Exec(b.deleteString(), b.args...) if err != nil { return 0, err } @@ -305,13 +296,13 @@ func (b *Builder) Delete(zeroValues ...string) (affected int64, err error) { } //gosql.Model(&User{}).Where("status = 0").Count() -func (b *Builder) Count(zeroValues ...string) (num int64, err error) { +func (b *ModelStruct) Count(zeroValues ...string) (num int64, err error) { b.initModel() m := zeroValueFilter(b.reflectModel(nil), zeroValues) //If where is empty, the primary key where condition is generated automatically b.generateWhere(m) - err = b.db().Get(&num, b.countString(), b.args...) + err = b.db.Get(&num, b.countString(), b.args...) return num, err } diff --git a/builder_test.go b/model_test.go similarity index 98% rename from builder_test.go rename to model_test.go index aac44fa..9f4a68f 100644 --- a/builder_test.go +++ b/model_test.go @@ -93,8 +93,8 @@ VALUES ) func RunWithSchema(t *testing.T, test func(t *testing.T)) { - db := DB() - db2 := DB("db2") + db := Sqlx() + db2 := Sqlx("db2") defer func() { for k := range createSchemas { _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", k)) @@ -124,8 +124,8 @@ func RunWithSchema(t *testing.T, test func(t *testing.T)) { } func initDatas(t *testing.T) { - db := DB() - db2 := DB("db2") + db := Sqlx() + db2 := Sqlx("db2") for k, v := range datas { udb := db if k == "photos" { @@ -536,7 +536,7 @@ func TestBuilder_Relation1(t *testing.T) { RunWithSchema(t, func(t *testing.T) { initDatas(t) moment := &MomentList{} - err := Model(moment).Relation("User", func(b *Builder) { + err := Model(moment).Relation("User", func(b *ModelStruct) { b.Where("status = 1") }).Where("status = 1 and id = ?", 14).Get() @@ -552,7 +552,7 @@ func TestBuilder_Relation1(t *testing.T) { func TestBuilder_Relation2(t *testing.T) { RunWithSchema(t, func(t *testing.T) { var moments = make([]*MomentList, 0) - err := Model(&moments).Relation("User", func(b *Builder) { + err := Model(&moments).Relation("User", func(b *ModelStruct) { b.Where("status = 0") }).Where("status = 1").Limit(10).All() diff --git a/relation.go b/relation.go index 6992e03..ebfbc67 100644 --- a/relation.go +++ b/relation.go @@ -30,8 +30,8 @@ func eachField(t reflect.Type, fn func(field reflect.StructField, val string, na return nil } -func newModel(value reflect.Value, connection string) *Builder { - var m *Builder +func newModel(value reflect.Value, connection string) *ModelStruct { + var m *ModelStruct if connection != "" { m = Use(connection).Model(value.Interface()) } else { diff --git a/relation_test.go b/relation_test.go index f7bb31c..563b968 100644 --- a/relation_test.go +++ b/relation_test.go @@ -17,7 +17,7 @@ func TestRelationOne2(t *testing.T) { RunWithSchema(t, func(t *testing.T) { initDatas(t) moment := &UserMoment{} - err := Model(moment).Relation("Moments", func(b *Builder) { + err := Model(moment).Relation("Moments", func(b *ModelStruct) { b.Limit(2) }).Where("id = ?", 5).Get() diff --git a/sql_builder.go b/sql_builder.go index 2027df6..d469903 100644 --- a/sql_builder.go +++ b/sql_builder.go @@ -6,6 +6,7 @@ import ( ) type SQLBuilder struct { + dialect Dialect fields string table string forceIndex string @@ -45,7 +46,7 @@ func (s *SQLBuilder) queryString() string { s.fields = "*" } - table := "`" + s.table + "`" + table := s.dialect.Quote(s.table) if s.forceIndex != "" { table += fmt.Sprintf(" force index(%s)", s.forceIndex) } @@ -59,7 +60,7 @@ func (s *SQLBuilder) queryString() string { //countString Assemble the count statement func (s *SQLBuilder) countString() string { - query := fmt.Sprintf("%sSELECT count(*) FROM `%s` %s", s.hint, s.table, s.where) + query := fmt.Sprintf("%sSELECT count(*) FROM %s %s", s.hint, s.dialect.Quote(s.table), s.where) query = strings.TrimRight(query, " ") query = query + ";" @@ -70,12 +71,12 @@ func (s *SQLBuilder) countString() string { func (s *SQLBuilder) insertString(params map[string]interface{}) string { var cols, vals []string for _, k := range sortedParamKeys(params) { - cols = append(cols, fmt.Sprintf("`%s`", k)) + cols = append(cols, s.dialect.Quote(k)) vals = append(vals, "?") s.args = append(s.args, params[k]) } - return fmt.Sprintf("INSERT INTO `%s` (%s) VALUES(%s);", s.table, strings.Join(cols, ","), strings.Join(vals, ",")) + return fmt.Sprintf("INSERT INTO %s (%s) VALUES(%s);", s.dialect.Quote(s.table), strings.Join(cols, ","), strings.Join(vals, ",")) } //updateString Assemble the update statement @@ -85,17 +86,17 @@ func (s *SQLBuilder) updateString(params map[string]interface{}) string { for _, k := range sortedParamKeys(params) { if e, ok := params[k].(*expr); ok { - updateFields = append(updateFields, fmt.Sprintf("%s=%s", fmt.Sprintf("`%s`", k), e.expr)) + updateFields = append(updateFields, fmt.Sprintf("%s=%s", s.dialect.Quote(k), e.expr)) args = append(args, e.args...) } else { - updateFields = append(updateFields, fmt.Sprintf("%s=?", fmt.Sprintf("`%s`", k))) + updateFields = append(updateFields, fmt.Sprintf("%s=?", s.dialect.Quote(k))) args = append(args, params[k]) } } args = append(args, s.args...) s.args = args - query := fmt.Sprintf("UPDATE `%s` SET %s %s", s.table, strings.Join(updateFields, ","), s.where) + query := fmt.Sprintf("UPDATE %s SET %s %s", s.dialect.Quote(s.table), strings.Join(updateFields, ","), s.where) query = strings.TrimRight(query, " ") query = query + ";" return query @@ -103,7 +104,7 @@ func (s *SQLBuilder) updateString(params map[string]interface{}) string { //deleteString Assemble the delete statement func (s *SQLBuilder) deleteString() string { - query := fmt.Sprintf("DELETE FROM `%s` %s", s.table, s.where) + query := fmt.Sprintf("DELETE FROM %s %s", s.dialect.Quote(s.table), s.where) query = strings.TrimRight(query, " ") query = query + ";" return query diff --git a/sql_builder_test.go b/sql_builder_test.go index ea3b686..7799db7 100644 --- a/sql_builder_test.go +++ b/sql_builder_test.go @@ -8,10 +8,11 @@ import ( func TestSQLBuilder_queryString(t *testing.T) { b := &SQLBuilder{ - table: "users", - order: "id desc", - limit: "0", - offset: "10", + dialect: mustGetDialect("mysql"), + table: "users", + order: "id desc", + limit: "0", + offset: "10", } b.Where("id = ?", 1) @@ -25,6 +26,7 @@ func TestSQLBuilder_queryString(t *testing.T) { func TestSQLBuilder_queryForceIndexString(t *testing.T) { b := &SQLBuilder{ + dialect: mustGetDialect("mysql"), table: "users", order: "id desc", forceIndex: "idx_user", @@ -43,7 +45,8 @@ func TestSQLBuilder_queryForceIndexString(t *testing.T) { func TestSQLBuilder_insertString(t *testing.T) { b := &SQLBuilder{ - table: "users", + dialect: mustGetDialect("mysql"), + table: "users", } query := b.insertString(map[string]interface{}{ @@ -62,7 +65,8 @@ func TestSQLBuilder_insertString(t *testing.T) { func TestSQLBuilder_updateString(t *testing.T) { b := &SQLBuilder{ - table: "users", + dialect: mustGetDialect("mysql"), + table: "users", } b.Where("id = ?", 1) @@ -81,7 +85,8 @@ func TestSQLBuilder_updateString(t *testing.T) { func TestSQLBuilder_deleteString(t *testing.T) { b := &SQLBuilder{ - table: "users", + dialect: mustGetDialect("mysql"), + table: "users", } b.Where("id = ?", 1) @@ -96,7 +101,8 @@ func TestSQLBuilder_deleteString(t *testing.T) { func TestSQLBuilder_countString(t *testing.T) { b := &SQLBuilder{ - table: "users", + dialect: mustGetDialect("mysql"), + table: "users", } b.Where("id = ?", 1) @@ -106,3 +112,30 @@ func TestSQLBuilder_countString(t *testing.T) { t.Error("sql builder count error", query) } } + +func TestSQLBuilder_Dialect(t *testing.T) { + testData := map[string]string{ + "mysql": "INSERT INTO `users` (`created_at`,`email`,`id`,`name`,`updated_at`) VALUES(?,?,?,?,?);", + "postgres": `INSERT INTO "users" ("created_at","email","id","name","updated_at") VALUES(?,?,?,?,?);`, + "sqlite3": `INSERT INTO "users" ("created_at","email","id","name","updated_at") VALUES(?,?,?,?,?);`, + } + + for k, v := range testData { + b := &SQLBuilder{ + dialect: mustGetDialect(k), + table: "users", + } + + query := b.insertString(map[string]interface{}{ + "id": 1, + "name": "test", + "email": "test@test.com", + "created_at": "2018-07-11 11:58:21", + "updated_at": "2018-07-11 11:58:21", + }) + + if query != v { + t.Error(fmt.Sprintf("sql builder %s dialect insert error", k), query) + } + } +}