diff --git a/internal/dao/cache.go b/internal/dao/cache.go index d9baf64d..10959a6d 100644 --- a/internal/dao/cache.go +++ b/internal/dao/cache.go @@ -19,6 +19,7 @@ const ( PageByKeySiteNameKey = "page#key=%s;site_name=%s" CommentByIDKey = "comment#id=%d" CommentChildIDsByIDKey = "comment_child_ids#id=%d" + NotifyByUserCommentKey = "notify#user_id=%d;comment_id=%d" ) type DaoCache struct { diff --git a/internal/dao/migrate.go b/internal/dao/migrate.go index 82a5babf..718db55f 100644 --- a/internal/dao/migrate.go +++ b/internal/dao/migrate.go @@ -1,6 +1,8 @@ package dao import ( + "os" + "github.com/ArtalkJS/Artalk/internal/entity" "github.com/ArtalkJS/Artalk/internal/log" ) @@ -21,6 +23,11 @@ func (dao *Dao) MigrateModels() { // because there are many different DBs and the implementation of foreign keys may be different, // and the DB may not support foreign keys, so don't rely on the foreign key function of the DB system. dao.DropConstraintsIfExist() + + // Merge pages + if os.Getenv("ATK_DB_MIGRATOR_FUNC_MERGE_PAGES") == "1" { + dao.MergePages() + } } // Remove all constraints @@ -87,3 +94,47 @@ func (dao *Dao) MigrateRootID() { log.Info(TAG, "Root IDs generated successfully.") } + +func (dao *Dao) MergePages() { + // merge pages with same key and site_name, sum pv + pages := []*entity.Page{} + + // load all pages + if err := dao.DB().Order("id ASC").Find(&pages).Error; err != nil { + log.Fatal("Failed to load pages. ", err.Error) + } + beforeLen := len(pages) + + // merge pages + mergedPages := map[string]*entity.Page{} + for _, page := range pages { + key := page.SiteName + page.Key + if _, ok := mergedPages[key]; !ok { + mergedPages[key] = page + } else { + mergedPages[key].PV += page.PV + mergedPages[key].VoteUp += page.VoteUp + mergedPages[key].VoteDown += page.VoteDown + } + } + + // delete all pages + dao.DB().Exec("DELETE FROM pages") + + // insert merged pages + pages = []*entity.Page{} + for _, page := range mergedPages { + pages = append(pages, page) + } + if err := dao.DB().CreateInBatches(pages, 1000); err.Error != nil { + log.Fatal("Failed to insert merged pages. ", err.Error) + } + + // drop page AccessibleURL column + if dao.DB().Migrator().HasColumn(&entity.Page{}, "accessible_url") { + dao.DB().Migrator().DropColumn(&entity.Page{}, "accessible_url") + } + + log.Info("Pages merged successfully. Before pages: ", beforeLen, ", After pages: ", len(mergedPages), ", Deleted pages: ", beforeLen-len(mergedPages)) + os.Exit(0) +} diff --git a/internal/dao/query_find_create.go b/internal/dao/query_find_create.go index 8bcfab48..efbfc583 100644 --- a/internal/dao/query_find_create.go +++ b/internal/dao/query_find_create.go @@ -5,26 +5,63 @@ import ( "strings" "github.com/ArtalkJS/Artalk/internal/entity" + "github.com/ArtalkJS/Artalk/internal/log" "github.com/ArtalkJS/Artalk/internal/utils" + "golang.org/x/sync/singleflight" ) -func (dao *Dao) FindCreateSite(siteName string) entity.Site { - site := dao.FindSite(siteName) - if site.IsEmpty() { - site = dao.NewSite(siteName, "") +var findCreateSingleFlightGroup = new(singleflight.Group) + +type EntityHasIsEmpty interface { + IsEmpty() bool +} + +// FindCreateAction (Thread Safe) +// +// Use singleflight.Group to prevent duplicate creation if multiple goroutines access at the same time. +func FindCreateAction[T EntityHasIsEmpty]( + key string, + findAction func() (T, error), + createAction func() (T, error), +) (T, error) { + result, err, _ := findCreateSingleFlightGroup.Do(key, func() (any, error) { + r, err := findAction() + if err != nil { + return nil, err + } + if r.IsEmpty() { + if r, err = createAction(); err != nil { + return nil, err + } + } + return r, nil + }) + if err != nil { + log.Error("[FindCreate] ", err) + return result.(T), err } - return site + return result.(T), nil +} + +func (dao *Dao) FindCreateSite(siteName string) entity.Site { + r, _ := FindCreateAction(fmt.Sprintf(SiteByNameKey, siteName), func() (entity.Site, error) { + return dao.FindSite(siteName), nil + }, func() (entity.Site, error) { + return dao.NewSite(siteName, ""), nil + }) + return r } func (dao *Dao) FindCreatePage(pageKey string, pageTitle string, siteName string) entity.Page { - page := dao.FindPage(pageKey, siteName) - if page.IsEmpty() { - page = dao.NewPage(pageKey, pageTitle, siteName) - } - return page + r, _ := FindCreateAction(fmt.Sprintf(PageByKeySiteNameKey, pageKey, siteName), func() (entity.Page, error) { + return dao.FindPage(pageKey, siteName), nil + }, func() (entity.Page, error) { + return dao.NewPage(pageKey, pageTitle, siteName), nil + }) + return r } -func (dao *Dao) FindCreateUser(name string, email string, link string) (user entity.User, err error) { +func (dao *Dao) FindCreateUser(name string, email string, link string) (entity.User, error) { name = strings.TrimSpace(name) email = strings.TrimSpace(email) link = strings.TrimSpace(link) @@ -37,20 +74,22 @@ func (dao *Dao) FindCreateUser(name string, email string, link string) (user ent if link != "" && !utils.ValidateURL(link) { link = "" } - user = dao.FindUser(name, email) - if user.IsEmpty() { - user, err = dao.NewUser(name, email, link) // save a new user + return FindCreateAction(fmt.Sprintf(UserByNameEmailKey, name, email), func() (entity.User, error) { + return dao.FindUser(name, email), nil + }, func() (entity.User, error) { + user, err := dao.NewUser(name, email, link) // save a new user if err != nil { return entity.User{}, err } - } - return user, nil + return user, nil + }) } -func (dao *Dao) FindCreateNotify(userID uint, lookCommentID uint) entity.Notify { - notify := dao.FindNotify(userID, lookCommentID) - if notify.IsEmpty() { - notify = dao.NewNotify(userID, lookCommentID) - } - return notify +func (dao *Dao) FindCreateNotify(userID uint, commentID uint) entity.Notify { + r, _ := FindCreateAction(fmt.Sprintf(NotifyByUserCommentKey, userID, commentID), func() (entity.Notify, error) { + return dao.FindNotify(userID, commentID), nil + }, func() (entity.Notify, error) { + return dao.NewNotify(userID, commentID), nil + }) + return r } diff --git a/internal/dao/query_find_create_test.go b/internal/dao/query_find_create_test.go index 0ef4c074..bd807a7f 100644 --- a/internal/dao/query_find_create_test.go +++ b/internal/dao/query_find_create_test.go @@ -1,8 +1,12 @@ package dao_test import ( + "fmt" + "sync" "testing" + "time" + "github.com/ArtalkJS/Artalk/internal/dao" "github.com/ArtalkJS/Artalk/test" "github.com/stretchr/testify/assert" ) @@ -55,7 +59,41 @@ func TestFindCreatePage(t *testing.T) { t.Run("Find Existed Page", func(t *testing.T) { result := app.Dao().FindCreatePage("/test/1000.html", "", "Site A") assert.False(t, result.IsEmpty()) - assert.Equal(t, app.Dao().FindPage("/test/1000.html", "Site A"), result) + findPage := app.Dao().FindPage("/test/1000.html", "Site A") + assert.Equal(t, app.Dao().CookPage(&findPage), app.Dao().CookPage(&result)) + }) + + t.Run("Concurrent FindCreatePage", func(t *testing.T) { + var ( + pageKey = "/" + time.Now().String() + ".html" + pageTitle = "New Page Title " + time.Now().String() + siteName = "Site A" + ) + + // simulate concurrent requests + var wg sync.WaitGroup + + var idMap sync.Map + n := 10000 + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer wg.Done() + result := app.Dao().FindCreatePage(pageKey, pageTitle, siteName) + idMap.Store(result.ID, true) + }() + } + + wg.Wait() + + // count the number of different pages + count := 0 + idMap.Range(func(_, _ interface{}) bool { + count++ + return true + }) + + assert.Equal(t, 1, count, fmt.Sprintf("Concurrent FindCreatePage should return the same page, but got %d different pages", count)) }) } @@ -113,3 +151,64 @@ func TestFindCreateUser(t *testing.T) { assert.Equal(t, app.Dao().FindUser("userA", "user_a@qwqaq.com"), result) }) } + +type mockEntity struct { + ID int + Name string +} + +func (e mockEntity) IsEmpty() bool { + return e.ID == 0 +} + +func TestFindCreateAction(t *testing.T) { + app, _ := test.NewTestApp() + defer app.Cleanup() + + var calledTimes int32 + var mutex sync.Mutex + + increaseCalledTimes := func() { + mutex.Lock() + defer mutex.Unlock() + calledTimes++ + } + + t.Run("Concurrent FindCreateAction", func(t *testing.T) { + var wg sync.WaitGroup + + var instance mockEntity + findCreateFunc := func() (mockEntity, error) { + randKey := fmt.Sprintf("rand_key_%d", time.Now().UnixNano()) + return dao.FindCreateAction(randKey, func() (mockEntity, error) { + // findAction + time.Sleep(200 * time.Millisecond) // mock time consuming + return instance, nil + }, func() (mockEntity, error) { + // createAction + instance = mockEntity{ + ID: 1, + Name: "mockEntity", + } + time.Sleep(500 * time.Millisecond) // mock time consuming + increaseCalledTimes() + return instance, nil + }) + } + + n := 10000 + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer wg.Done() + result, err := findCreateFunc() + assert.NoError(t, err) + assert.False(t, result.IsEmpty(), "FindCreateAction should always return a non-empty entity") + }() + } + + wg.Wait() + + assert.Equal(t, int32(1), calledTimes, "Concurrent FindCreateAction should only call createAction once") + }) +}