diff --git a/.gitignore b/.gitignore index 117f92f52..9e62d5315 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ documents coverage.txt +coverage*.out _book diff --git a/README.md b/README.md index de594acdc..95ce2d081 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,8 @@ Because of `ON CONFLICT`. +PostgreSQL not support ON CONFLICT IGNORE (`ON CONFLICT DO NOTHING`) because gorm can not get right `RowsAffected`. + ### GetOrCreate ```go @@ -22,7 +24,7 @@ Because of `ON CONFLICT`. db.Where(User{Name: "jinzhu"}).Attrs(User{Age: 30}).GetOrCreate(&user) ``` -### IGNORE/CreateOrUpdate +### IGNORE/ON CONFLICT UPDATE ```go // mysql: INSERT IGNORE INTO @@ -32,13 +34,27 @@ db.CreateOnConflict(User{UserName: "gorm"}, gorm.IGNORE) // mysql: INSERT INTO ... ON DUPLICATE KEY UPDATE ... db.CreateOnConflict(User{UserName: "gorm"}, User{LastLoginAt: time.Now()}) -// postgresql: INSERT INTO ... ON CONFLICT a_key DO UPDATE ... -db.CreateOnConflict(User{UserName: "gorm"}, "a_key", User{LastLoginAt: time.Now()}) +// postgresql: INSERT INTO ... ON CONFLICT ON CONSTRAINT constraint_name DO UPDATE ... +db.CreateOnConflict(User{UserName: "gorm"}, "constraint_name", User{LastLoginAt: time.Now()}) ``` -### CreateMany/CreateManyOnConflict +### CreateMany/CreateMany OnConflict + +```go +// mysql and sqlite: insert multiple; insert multiple ignore duplicate +// postgresql and mssql do not support ignore +db.CreateMany([]interface{}{&user1, &user2, &user3}, gorm.IGNORE) +db.CreateMany([]interface{}{&user1, &user2, &user3}) + +// mysql: insert on conflict update +db.CreateMany([]interface{}{&user1, &user2, &user3}, &User{UpdatedAt: now}) + +// postgresql: insert on confilct update +db.CreateMany([]interface{}{&user1, &user2, &user3}, 'constraint_name', &User{UpdatedAt: now}) -TODO +// Caution: mssql db driver will not raise error on duplicate +db.CreateMany([]interface{}{&user1, &user1, &user1}) +``` ## License diff --git a/callback_create.go b/callback_create.go index f81f61013..12038f15e 100644 --- a/callback_create.go +++ b/callback_create.go @@ -66,28 +66,66 @@ func createCallback(scope *Scope) { ) // Set columns; Add placeholders and vars for `value_list` - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if field.IsNormal && !field.IsIgnored { - if field.IsBlank && field.HasDefaultValue { - blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) - scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) - } else if !field.IsPrimaryKey || !field.IsBlank { - columns = append(columns, scope.Quote(field.DBName)) + var ( + columnsString string + placeholdersString string + ) + if values, ok := scope.Get("gorm:create_many"); ok { + // CreateMany + for _, field := range scope.Fields() { + if !field.IsPrimaryKey || !field.IsBlank { + columns = append(columns, field.DBName) + } + } + createMany := values.([](map[string]interface{})) + var placeholdersStrings []string + firstObjLength := len(createMany[0]) + for _, obj := range createMany { + if len(obj) != firstObjLength { + scope.Err(errors.New("createMany objects should have the same fields")) + return + } + placeholders = []string{} + for _, column := range columns { + if fieldValue, ok := obj[column]; ok { + placeholders = append(placeholders, scope.AddToVars(fieldValue)) + } else { + field, _ := scope.FieldByName(column) placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) } - } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { - for _, foreignKey := range field.Relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - columns = append(columns, scope.Quote(foreignField.DBName)) - placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) + } + placeholdersStrings = append(placeholdersStrings, "("+strings.Join(placeholders, ",")+")") + } + for index, column := range columns { + columns[index] = scope.Quote(column) + } + columnsString = strings.Join(columns, ",") + placeholdersString = strings.Join(placeholdersStrings, ",") + } else { + // Normal + for _, field := range scope.Fields() { + if scope.changeableField(field) { + if field.IsNormal && !field.IsIgnored { + if field.IsBlank && field.HasDefaultValue { + blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) + scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) + } else if !field.IsPrimaryKey || !field.IsBlank { + columns = append(columns, scope.Quote(field.DBName)) + placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) + } + } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { + for _, foreignKey := range field.Relationship.ForeignDBNames { + if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { + columns = append(columns, scope.Quote(foreignField.DBName)) + placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) + } } } } } + columnsString = strings.Join(columns, ",") + placeholdersString = "(" + strings.Join(placeholders, ",") + ")" } - columnsString := strings.Join(columns, ",") - placeholdersString := "(" + strings.Join(placeholders, ",") + ")" var ( returningColumn = "*" @@ -102,6 +140,7 @@ func createCallback(scope *Scope) { insertStr, ok := scope.Get("gorm:insert_option") if !ok { scope.Err(errors.New("gorm:insert_option not found")) + return } updateMap := obj.(map[string]interface{}) updateColumns := []string{} @@ -198,6 +237,9 @@ func createCallback(scope *Scope) { if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { primaryField.IsBlank = false scope.db.RowsAffected = 1 + if values, ok := scope.Get("gorm:create_many"); ok { + scope.db.RowsAffected = int64(len(values.([](map[string]interface{})))) + } } } else { scope.Err(ErrUnaddressable) diff --git a/create_test.go b/create_test.go index 4b59a9b50..b1864e100 100644 --- a/create_test.go +++ b/create_test.go @@ -384,3 +384,121 @@ func TestGetOrCreate(t *testing.T) { t.Error("Expected should be NotFound, but got", res.Error) } } + +func TestCreateMany(t *testing.T) { + DB.Delete(&Email{}) + if res := DB.CreateMany([]interface{}{ + &Email{UserId: 1, Email: "jeff@qq.com"}, + &Email{UserId: 2, Email: "alan@qq.com"}, + &Email{UserId: 3, Email: "alice@qq.com"}, + &Email{UserId: 4, Email: "bob@qq.com"}, + }); res.Error != nil || res.RowsAffected != 4 { + t.Error(res.Error, "OR RowsAffected should be 4, got %d", res.RowsAffected) + } + emails := []Email{} + DB.Model(&Email{}).Find(&emails) + if len(emails) != 4 { + t.Error("count should be 4") + } + createAt := time.Time{} + userIds := [4]int{} + emailEmails := [4]string{} + for index, email := range emails { + if createAt.Equal(time.Time{}) { + createAt = email.CreatedAt + } else if !createAt.Equal(email.CreatedAt) { + t.Errorf("%s, %s", email.CreatedAt, createAt) + } + t.Log(email) + userIds[index] = email.UserId + emailEmails[index] = email.Email + } + if userIds != [4]int{1, 2, 3, 4} { + t.Errorf("expected []int{1,2,3,4}, but got %d", userIds) + } + if emailEmails != [4]string{"jeff@qq.com", "alan@qq.com", "alice@qq.com", "bob@qq.com"} { + t.Errorf("%v", emailEmails) + } + DB.Delete(&Email{}) + + // fields not same + if res := DB.CreateMany([]interface{}{ + &Email{Id: 1, UserId: 1, Email: "jeff@qq.com"}, + &Email{UserId: 2, Email: "alan@qq.com"}, + &Email{UserId: 3, Email: "alice@qq.com"}, + &Email{UserId: 4, Email: "bob@qq.com"}, + }); res.Error.Error() != "createMany objects should have the same fields" { + t.Error("Expected error createMany objects should have the same fields, got", res.Error.Error()) + } + + // OnConflict IGNORE + if DB.Dialect().GetName() == "mssql" { + t.Log("mssqldb driver do not raise error when insert many on conflict") + } else { + // Duplicate + if res := DB.CreateMany([]interface{}{ + &Email{Id: 1, UserId: 1, Email: "jeff@qq.com"}, + &Email{Id: 1, UserId: 2, Email: "alan@qq.com"}, + &Email{Id: 1, UserId: 3, Email: "alice@qq.com"}, + &Email{Id: 1, UserId: 4, Email: "bob@qq.com"}, + }); res.Error == nil { + t.Error("Expected error Duplicate entry, but got nil") + } + + } + if DB.Dialect().GetName() == "mysql" || DB.Dialect().GetName() == "sqlite3" { + // Ignore + if res := DB.CreateMany([]interface{}{ + &Email{Id: 1, UserId: 1, Email: "jeff@qq.com"}, + &Email{Id: 1, UserId: 2, Email: "alan@qq.com"}, + &Email{Id: 1, UserId: 3, Email: "alice@qq.com"}, + &Email{Id: 1, UserId: 4, Email: "bob@qq.com"}, + }, "IGNORE"); res.Error != nil { + t.Error(res.Error) + } + emails = []Email{} + DB.Model(&Email{}).Find(&emails) + if len(emails) != 1 { + t.Error("count should be 1, but got", len(emails)) + } + } + + // OnConflict UPDATE + DB.Delete(&Email{}) + if DB.Dialect().GetName() == "postgres" || DB.Dialect().GetName() == "mysql" { + DB.Create(&Email{Id: 1, UserId: 10086, Email: "hello@example.com"}) + + var updateOrIgnore []interface{} + if DB.Dialect().GetName() == "postgres" { + updateOrIgnore = []interface{}{"emails_pkey", &Email{Id: 100, UserId: 10010}} + } else { + updateOrIgnore = []interface{}{&Email{Id: 100, UserId: 10010}} + } + if res := DB.CreateMany([]interface{}{ + &Email{Id: 1, UserId: 10086, Email: "hello@example.com"}, // Duplicate + &Email{Id: 2, UserId: 10086, Email: "hello@example.com"}, // Normal + }, updateOrIgnore...); res.Error != nil { + t.Error(res.Error) + } + + emails := []Email{} + if DB.Model(&Email{}).Find(&emails); len(emails) != 2 { + t.Error("Expected 2 emails, got", len(emails)) + } + emails2 := emails[0] + emails100 := emails[1] + if emails2.Id != 2 { + emails2 = emails[1] + emails100 = emails[0] + } + if emails2.Id != 2 || emails2.UserId != 10086 { + // insert success + t.Error("Expected email 2 user ID 10086 , got", emails[0].Id, emails[0].UserId) + } + if emails100.Id != 100 || emails100.UserId != 10010 { + // update success + t.Error("Expected email 100 user ID 10010 , got", emails[1].Id, emails[1].UserId) + } + } + DB.Delete(&Email{}) +} diff --git a/main.go b/main.go index 265dc830c..4c7e5d768 100644 --- a/main.go +++ b/main.go @@ -506,21 +506,27 @@ func (s *DB) Create(value interface{}) *DB { // db.CreateOnConflict(User{UserName: "gorm"}, "key", User{LastLoginAt: time.Now()}) // INSERT INTO ... ON CONFLICT key DO UPDATE last_login_at = ... func (s *DB) CreateOnConflict(value interface{}, updateOrIgnore ...interface{}) *DB { scope := s.NewScope(value) - insertMod, updateStr, updateObj := scope.Dialect().OnConflict(updateOrIgnore...) - if insertMod == "" && updateStr == "" { + if !scope.onConflict(updateOrIgnore...) { s.logger.Print("warning", "Not support on conflict:", scope.Dialect().GetName()) } - if insertMod != "" { - scope.Set("gorm:insert_modifier", insertMod) + + return scope.callCallbacks(s.parent.callbacks.creates).db +} + +// CreateMany Ignore only support MySQL and sqlite +// db.CreateMany([]interface{}{&user1, &user2, &user3}, gorm.IGNORE) +// db.CreateMany([]interface{}{&user1, &user2, &user3}) +func (s *DB) CreateMany(values []interface{}, updateOrIgnore ...interface{}) *DB { + var createMany [](map[string]interface{}) + for _, value := range values { + createMany = append(createMany, convertInterfaceToMap(value, false, s)) } - if updateStr != "" { - scope.Set("gorm:insert_option", updateStr) - if updateObj != nil { - updateMap := convertInterfaceToMap(updateObj, false, s) - scope.Set("gorm:on_conflict_update", updateMap) + scope := s.NewScope(values[0]).Set("gorm:create_many", createMany) + if len(updateOrIgnore) > 0 { + if !scope.onConflict(updateOrIgnore...) { + s.logger.Print("warning", "Not support on conflict:", scope.Dialect().GetName()) } } - return scope.callCallbacks(s.parent.callbacks.creates).db } diff --git a/scope.go b/scope.go index 5f49284e3..6fc43e0cf 100644 --- a/scope.go +++ b/scope.go @@ -1419,3 +1419,21 @@ func (scope *Scope) hasConditions() bool { len(scope.Search.orConditions) > 0 || len(scope.Search.notConditions) > 0 } + +func (scope *Scope) onConflict(updateOrIgnore ...interface{}) bool { + insertMod, updateStr, updateObj := scope.Dialect().OnConflict(updateOrIgnore...) + if insertMod == "" && updateStr == "" { + return false + } + if insertMod != "" { + scope.Set("gorm:insert_modifier", insertMod) + } + if updateStr != "" { + scope.Set("gorm:insert_option", updateStr) + if updateObj != nil { + updateMap := convertInterfaceToMap(updateObj, false, scope.db) + scope.Set("gorm:on_conflict_update", updateMap) + } + } + return true +}