From 375552ca8b9a0b8305c6e8cb7314dc8020119311 Mon Sep 17 00:00:00 2001 From: Wenbo han Date: Sat, 8 Apr 2023 21:49:22 +0800 Subject: [PATCH] Add new methods to Orm --- contracts/database/orm/orm.go | 3 + database/gorm/gorm.go | 291 ++++++++++++++++++++++++---------- database/gorm/gorm_test.go | 120 ++++++++++++++ 3 files changed, 327 insertions(+), 87 deletions(-) diff --git a/contracts/database/orm/orm.go b/contracts/database/orm/orm.go index 89b9dfc65..4ab5e2026 100644 --- a/contracts/database/orm/orm.go +++ b/contracts/database/orm/orm.go @@ -32,6 +32,7 @@ type Query interface { Distinct(args ...any) Query Exec(sql string, values ...any) (*Result, error) Find(dest any, conds ...any) error + FindOrFail(dest any, conds ...any) error First(dest any) error FirstOrCreate(dest any, conds ...any) error FirstOr(dest any, callback func() error) error @@ -54,6 +55,7 @@ type Query interface { Pluck(column string, dest any) error Raw(sql string, values ...any) Query Save(value any) error + SaveQuietly(value any) error Scan(dest any) error Scopes(funcs ...func(Query) Query) Query Select(query any, args ...any) Query @@ -62,6 +64,7 @@ type Query interface { Updates(values any) (*Result, error) UpdateOrCreate(dest any, attributes any, values any) error Where(query any, args ...any) Query + WithoutEvents() Query WithTrashed() Query With(query string, args ...any) Query } diff --git a/database/gorm/gorm.go b/database/gorm/gorm.go index da97129fc..785def16f 100644 --- a/database/gorm/gorm.go +++ b/database/gorm/gorm.go @@ -119,29 +119,13 @@ func readWriteSeparate(connection string, instance *gorm.DB, readConfigs, writeC })) } -func NewQuery(ctx context.Context, connection string) (*Query, error) { - db, err := New(connection) - if err != nil { - return nil, err - } - if db == nil { - return nil, nil - } - - if ctx != nil { - db = db.WithContext(ctx) - } - - return NewQueryInstance(db), nil -} - type Transaction struct { contractsorm.Query instance *gorm.DB } -func NewTransaction(instance *gorm.DB) *Transaction { - return &Transaction{Query: NewQueryInstance(instance), instance: instance} +func NewTransaction(tx *gorm.DB) *Transaction { + return &Transaction{Query: NewQueryWithInstance(nil, tx), instance: tx} } func (r *Transaction) Commit() error { @@ -153,11 +137,36 @@ func (r *Transaction) Rollback() error { } type Query struct { - instance *gorm.DB + instance *gorm.DB + withoutEvents bool } -func NewQueryInstance(instance *gorm.DB) *Query { - return &Query{instance} +func NewQuery(ctx context.Context, connection string) (*Query, error) { + db, err := New(connection) + if err != nil { + return nil, err + } + if db == nil { + return nil, nil + } + + if ctx != nil { + db = db.WithContext(ctx) + } + + return NewQueryWithInstance(nil, db), nil +} + +func NewQueryWithInstance(query *Query, instance *gorm.DB) *Query { + if query == nil { + return &Query{instance: instance} + } + + return &Query{instance: instance, withoutEvents: query.withoutEvents} +} + +func NewQueryWithWithoutEvents(query *Query) *Query { + return &Query{instance: query.instance, withoutEvents: true} } func (r *Query) Association(association string) contractsorm.Association { @@ -195,10 +204,8 @@ func (r *Query) Create(value any) error { } func (r *Query) Delete(dest any, conds ...any) (*contractsorm.Result, error) { - if deletingModel, ok := dest.(contractsorm.Deleting); ok { - if err := deletingModel.Deleting(r); err != nil { - return nil, err - } + if err := deleting(r, dest); err != nil { + return nil, err } res := r.instance.Delete(dest, conds...) @@ -206,10 +213,8 @@ func (r *Query) Delete(dest any, conds ...any) (*contractsorm.Result, error) { return nil, res.Error } - if deletedModel, ok := dest.(contractsorm.Deleted); ok { - if err := deletedModel.Deleted(r); err != nil { - return nil, err - } + if err := deleted(r, dest); err != nil { + return nil, err } return &contractsorm.Result{ @@ -220,7 +225,7 @@ func (r *Query) Delete(dest any, conds ...any) (*contractsorm.Result, error) { func (r *Query) Distinct(args ...any) contractsorm.Query { tx := r.instance.Distinct(args...) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) Exec(sql string, values ...any) (*contractsorm.Result, error) { @@ -232,23 +237,9 @@ func (r *Query) Exec(sql string, values ...any) (*contractsorm.Result, error) { } func (r *Query) Find(dest any, conds ...any) error { - if len(conds) > 0 { - switch cond := conds[0].(type) { - case string: - if cond == "" { - return ErrorMissingWhereClause - } - default: - reflectValue := reflect.Indirect(reflect.ValueOf(cond)) - switch reflectValue.Kind() { - case reflect.Slice, reflect.Array: - if reflectValue.Len() == 0 { - return ErrorMissingWhereClause - } - } - } + if err := filterFindConditions(conds...); err != nil { + return err } - if err := r.instance.Find(dest, conds...).Error; err != nil { return err } @@ -256,6 +247,23 @@ func (r *Query) Find(dest any, conds ...any) error { return retrieved(r, dest) } +func (r *Query) FindOrFail(dest any, conds ...any) error { + if err := filterFindConditions(conds...); err != nil { + return err + } + + res := r.instance.Find(dest, conds...) + if err := res.Error; err != nil { + return err + } + + if res.RowsAffected == 0 { + return orm.ErrRecordNotFound + } + + return retrieved(r, dest) +} + func (r *Query) First(dest any) error { res := r.instance.First(dest) if res.Error != nil { @@ -336,10 +344,8 @@ func (r *Query) FirstOrNew(dest any, attributes any, values ...any) error { } func (r *Query) ForceDelete(value any, conds ...any) (*contractsorm.Result, error) { - if forceDeletingModel, ok := value.(contractsorm.ForceDeleting); ok { - if err := forceDeletingModel.ForceDeleting(r); err != nil { - return nil, err - } + if err := forceDeleting(r, value); err != nil { + return nil, err } res := r.instance.Unscoped().Delete(value, conds...) @@ -348,10 +354,8 @@ func (r *Query) ForceDelete(value any, conds ...any) (*contractsorm.Result, erro } if res.RowsAffected > 0 { - if forceDeletedModel, ok := value.(contractsorm.ForceDeleted); ok { - if err := forceDeletedModel.ForceDeleted(r); err != nil { - return nil, err - } + if err := forceDeleted(r, value); err != nil { + return nil, err } } @@ -367,13 +371,13 @@ func (r *Query) Get(dest any) error { func (r *Query) Group(name string) contractsorm.Query { tx := r.instance.Group(name) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) Having(query any, args ...any) contractsorm.Query { tx := r.instance.Having(query, args...) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) Instance() *gorm.DB { @@ -383,13 +387,13 @@ func (r *Query) Instance() *gorm.DB { func (r *Query) Join(query string, args ...any) contractsorm.Query { tx := r.instance.Joins(query, args...) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) Limit(limit int) contractsorm.Query { tx := r.instance.Limit(limit) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) Load(model any, relation string, args ...any) error { @@ -455,31 +459,31 @@ func (r *Query) LoadMissing(model any, relation string, args ...any) error { func (r *Query) Model(value any) contractsorm.Query { tx := r.instance.Model(value) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) Offset(offset int) contractsorm.Query { tx := r.instance.Offset(offset) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) Omit(columns ...string) contractsorm.Query { tx := r.instance.Omit(columns...) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) Order(value any) contractsorm.Query { tx := r.instance.Order(value) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) OrWhere(query any, args ...any) contractsorm.Query { tx := r.instance.Or(query, args...) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) Paginate(page, limit int, dest any, total *int64) error { @@ -506,7 +510,7 @@ func (r *Query) Pluck(column string, dest any) error { func (r *Query) Raw(sql string, values ...any) contractsorm.Query { tx := r.instance.Raw(sql, values...) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) Save(value any) error { @@ -525,6 +529,10 @@ func (r *Query) Save(value any) error { return r.save(value) } +func (r *Query) SaveQuietly(value any) error { + return r.WithoutEvents().Save(value) +} + func (r *Query) Scan(dest any) error { return r.instance.Scan(dest).Error } @@ -532,13 +540,28 @@ func (r *Query) Scan(dest any) error { func (r *Query) Select(query any, args ...any) contractsorm.Query { tx := r.instance.Select(query, args...) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) +} + +func (r *Query) Scopes(funcs ...func(contractsorm.Query) contractsorm.Query) contractsorm.Query { + var gormFuncs []func(*gorm.DB) *gorm.DB + for _, item := range funcs { + gormFuncs = append(gormFuncs, func(tx *gorm.DB) *gorm.DB { + item(NewQueryWithInstance(r, tx)) + + return tx + }) + } + + tx := r.instance.Scopes(gormFuncs...) + + return NewQueryWithInstance(r, tx) } func (r *Query) Table(name string, args ...any) contractsorm.Query { tx := r.instance.Table(name, args...) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) Update(column string, value any) error { @@ -605,13 +628,17 @@ func (r *Query) UpdateOrCreate(dest any, attributes any, values any) error { func (r *Query) Where(query any, args ...any) contractsorm.Query { tx := r.instance.Where(query, args...) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) +} + +func (r *Query) WithoutEvents() contractsorm.Query { + return NewQueryWithWithoutEvents(r) } func (r *Query) WithTrashed() contractsorm.Query { tx := r.instance.Unscoped() - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) With(query string, args ...any) contractsorm.Query { @@ -619,8 +646,8 @@ func (r *Query) With(query string, args ...any) contractsorm.Query { switch arg := args[0].(type) { case func(contractsorm.Query) contractsorm.Query: newArgs := []any{ - func(db *gorm.DB) *gorm.DB { - query := arg(NewQueryInstance(db)) + func(tx *gorm.DB) *gorm.DB { + query := arg(NewQueryWithInstance(r, tx)) return query.(*Query).instance }, @@ -628,28 +655,13 @@ func (r *Query) With(query string, args ...any) contractsorm.Query { tx := r.instance.Preload(query, newArgs...) - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } } tx := r.instance.Preload(query, args...) - return NewQueryInstance(tx) -} - -func (r *Query) Scopes(funcs ...func(contractsorm.Query) contractsorm.Query) contractsorm.Query { - var gormFuncs []func(*gorm.DB) *gorm.DB - for _, item := range funcs { - gormFuncs = append(gormFuncs, func(db *gorm.DB) *gorm.DB { - item(&Query{db}) - - return db - }) - } - - tx := r.instance.Scopes(gormFuncs...) - - return NewQueryInstance(tx) + return NewQueryWithInstance(r, tx) } func (r *Query) selectCreate(value any) error { @@ -826,6 +838,10 @@ func (r *Query) save(value any) error { } func retrieved(query *Query, dest any) error { + if query.withoutEvents { + return nil + } + if retrievedModel, ok := dest.(contractsorm.Retrieved); ok { if err := retrievedModel.Retrieved(query); err != nil { return err @@ -836,6 +852,10 @@ func retrieved(query *Query, dest any) error { } func updating(query *Query, dest any) error { + if query.withoutEvents { + return nil + } + if updatingModel, ok := dest.(contractsorm.Updating); ok { if err := updatingModel.Updating(query); err != nil { return err @@ -846,6 +866,10 @@ func updating(query *Query, dest any) error { } func updated(query *Query, dest any) error { + if query.withoutEvents { + return nil + } + if updatedModel, ok := dest.(contractsorm.Updated); ok { if err := updatedModel.Updated(query); err != nil { return err @@ -856,6 +880,10 @@ func updated(query *Query, dest any) error { } func saving(query *Query, dest any) error { + if query.withoutEvents { + return nil + } + if savingModel, ok := dest.(contractsorm.Saving); ok { if err := savingModel.Saving(query); err != nil { return err @@ -866,6 +894,10 @@ func saving(query *Query, dest any) error { } func saved(query *Query, dest any) error { + if query.withoutEvents { + return nil + } + if savedModel, ok := dest.(contractsorm.Saved); ok { if err := savedModel.Saved(query); err != nil { return err @@ -876,6 +908,10 @@ func saved(query *Query, dest any) error { } func creating(query *Query, dest any) error { + if query.withoutEvents { + return nil + } + if creatingModel, ok := dest.(contractsorm.Creating); ok { if err := creatingModel.Creating(query); err != nil { return err @@ -886,6 +922,10 @@ func creating(query *Query, dest any) error { } func created(query *Query, dest any) error { + if query.withoutEvents { + return nil + } + if createdModel, ok := dest.(contractsorm.Created); ok { if err := createdModel.Created(query); err != nil { return err @@ -895,6 +935,62 @@ func created(query *Query, dest any) error { return nil } +func deleting(query *Query, dest any) error { + if query.withoutEvents { + return nil + } + + if deletingModel, ok := dest.(contractsorm.Deleting); ok { + if err := deletingModel.Deleting(query); err != nil { + return err + } + } + + return nil +} + +func deleted(query *Query, dest any) error { + if query.withoutEvents { + return nil + } + + if deletedModel, ok := dest.(contractsorm.Deleted); ok { + if err := deletedModel.Deleted(query); err != nil { + return err + } + } + + return nil +} + +func forceDeleting(query *Query, dest any) error { + if query.withoutEvents { + return nil + } + + if forceDeletingModel, ok := dest.(contractsorm.ForceDeleting); ok { + if err := forceDeletingModel.ForceDeleting(query); err != nil { + return err + } + } + + return nil +} + +func forceDeleted(query *Query, dest any) error { + if query.withoutEvents { + return nil + } + + if forceDeletedModel, ok := dest.(contractsorm.ForceDeleted); ok { + if err := forceDeletedModel.ForceDeleted(query); err != nil { + return err + } + } + + return nil +} + func create(query *Query, dest any) error { if err := saving(query, dest); err != nil { return err @@ -916,3 +1012,24 @@ func create(query *Query, dest any) error { return nil } + +func filterFindConditions(conds ...any) error { + if len(conds) > 0 { + switch cond := conds[0].(type) { + case string: + if cond == "" { + return ErrorMissingWhereClause + } + default: + reflectValue := reflect.Indirect(reflect.ValueOf(cond)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if reflectValue.Len() == 0 { + return ErrorMissingWhereClause + } + } + } + } + + return nil +} diff --git a/database/gorm/gorm_test.go b/database/gorm/gorm_test.go index 54b5c6c52..520444eb1 100644 --- a/database/gorm/gorm_test.go +++ b/database/gorm/gorm_test.go @@ -60,6 +60,12 @@ func (u *User) Saving(query contractsorm.Query) error { if u.Name == "event_save_update_save_name" { u.Avatar = "event_save_update_save_avatar" } + if u.Name == "event_save_without_name" { + u.Avatar = "event_save_without_avatar" + } + if u.Name == "event_save_quietly_name" { + u.Avatar = "event_save_quietly_avatar" + } return nil } @@ -71,6 +77,12 @@ func (u *User) Saved(query contractsorm.Query) error { if u.Name == "event_save_update_save_name" { u.Avatar = u.Avatar + "1" } + if u.Name == "event_save_without_name" { + u.Avatar = "event_saved_without_avatar" + } + if u.Name == "event_save_quietly_name" { + u.Avatar = "event_saved_quietly_avatar" + } return nil } @@ -132,6 +144,9 @@ func (u *User) Retrieved(query contractsorm.Query) error { if u.Name == "event_retrieve_first_or_new_name" { u.Name = "event_retrieved_first_or_new_name" } + if u.Name == "event_retrieve_find_or_fail_name" { + u.Name = "event_retrieved_find_or_fail_name" + } return nil } @@ -801,6 +816,53 @@ func (s *GormQueryTestSuite) TestFind() { } } +func (s *GormQueryTestSuite) TestFindOrFail() { + for _, query := range s.queries { + tests := []struct { + name string + setup func() + }{ + { + name: "success", + setup: func() { + user := User{Name: "find_user"} + s.Nil(query.Create(&user)) + s.True(user.ID > 0) + + var user2 User + s.Nil(query.FindOrFail(&user2, user.ID)) + s.True(user2.ID > 0) + }, + }, + { + name: "error", + setup: func() { + var user User + s.ErrorIs(query.FindOrFail(&user, 10000), orm.ErrRecordNotFound) + }, + }, + { + name: "success with event", + setup: func() { + user := User{Name: "event_retrieve_find_or_fail_name", Avatar: "find_or_fail_avatar"} + s.Nil(query.Create(&user)) + s.True(user.ID > 0) + + var user1 User + s.Nil(query.Where("name", "event_retrieve_find_or_fail_name").Find(&user1)) + s.True(user1.ID > 0) + s.Equal("event_retrieved_find_or_fail_name", user1.Name) + }, + }, + } + for _, test := range tests { + s.Run(test.name, func() { + test.setup() + }) + } + } +} + func (s *GormQueryTestSuite) TestFirst() { for _, query := range s.queries { tests := []struct { @@ -1703,7 +1765,36 @@ func (s *GormQueryTestSuite) TestSave() { test.setup() }) } + } +} + +func (s *GormQueryTestSuite) TestSaveQuietly() { + for _, query := range s.queries { + tests := []struct { + name string + setup func() + }{ + { + name: "success", + setup: func() { + user := User{Name: "event_save_quietly_name", Avatar: "save_quietly_avatar"} + s.Nil(query.SaveQuietly(&user)) + s.True(user.ID > 0) + s.Equal("event_save_quietly_name", user.Name) + s.Equal("save_quietly_avatar", user.Avatar) + var user1 User + s.Nil(query.Find(&user1, user.ID)) + s.Equal("event_save_quietly_name", user1.Name) + s.Equal("save_quietly_avatar", user1.Avatar) + }, + }, + } + for _, test := range tests { + s.Run(test.name, func() { + test.setup() + }) + } } } @@ -1921,6 +2012,35 @@ func (s *GormQueryTestSuite) TestWhere() { } } +func (s *GormQueryTestSuite) TestWithoutEvents() { + for _, query := range s.queries { + tests := []struct { + name string + setup func() + }{ + { + name: "success", + setup: func() { + user := User{Name: "event_save_without_name", Avatar: "without_events_avatar"} + s.Nil(query.WithoutEvents().Save(&user)) + s.True(user.ID > 0) + s.Equal("without_events_avatar", user.Avatar) + + var user1 User + s.Nil(query.Find(&user1, user.ID)) + s.Equal("event_save_without_name", user1.Name) + s.Equal("without_events_avatar", user1.Avatar) + }, + }, + } + for _, test := range tests { + s.Run(test.name, func() { + test.setup() + }) + } + } +} + func (s *GormQueryTestSuite) TestWith() { for driver, query := range s.queries { s.Run(driver.String(), func() {