Skip to content

Commit

Permalink
fix(dao): ensure find and create functions thread safe (#845)
Browse files Browse the repository at this point in the history
* fix(dao): ensure find and create functions thread safe

* feat(db/migrator): provide tool for merging duplicate page
  • Loading branch information
qwqcode authored Apr 30, 2024
1 parent fbd1cd9 commit a9331b4
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 23 deletions.
1 change: 1 addition & 0 deletions internal/dao/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const (
PageByKeySiteNameKey = "page#key=%s;site_name=%s"
CommentByIDKey = "comment#id=%d"
CommentChildIDsByIDKey = "comment_child_ids#id=%d"
NotifyByUserCommentKey = "notify#user_id=%d;comment_id=%d"
)

type DaoCache struct {
Expand Down
51 changes: 51 additions & 0 deletions internal/dao/migrate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dao

import (
"os"

"github.com/ArtalkJS/Artalk/internal/entity"
"github.com/ArtalkJS/Artalk/internal/log"
)
Expand All @@ -21,6 +23,11 @@ func (dao *Dao) MigrateModels() {
// because there are many different DBs and the implementation of foreign keys may be different,
// and the DB may not support foreign keys, so don't rely on the foreign key function of the DB system.
dao.DropConstraintsIfExist()

// Merge pages
if os.Getenv("ATK_DB_MIGRATOR_FUNC_MERGE_PAGES") == "1" {
dao.MergePages()
}
}

// Remove all constraints
Expand Down Expand Up @@ -87,3 +94,47 @@ func (dao *Dao) MigrateRootID() {

log.Info(TAG, "Root IDs generated successfully.")
}

// MergePages collapses pages that share the same site name and key into a
// single row and then terminates the process.
//
// Pages are loaded ordered by id. For every (SiteName, Key) pair the first
// page encountered is kept and the PV, VoteUp and VoteDown counters of the
// later duplicates are added to it. The pages table is then emptied and the
// surviving rows are re-inserted inside one transaction, so a failure while
// inserting rolls the delete back instead of leaving the table empty.
// Finally the accessible_url column is dropped when it exists, a summary is
// logged and the process exits with status 0.
func (dao *Dao) MergePages() {
	// load all pages, oldest first, so the lowest id of each group survives
	var pages []*entity.Page
	if err := dao.DB().Order("id ASC").Find(&pages).Error; err != nil {
		log.Fatal("Failed to load pages. ", err)
		return // defensive: do not touch the table if log.Fatal ever returns
	}
	beforeLen := len(pages)

	// A struct key keeps the two components separate; plain concatenation
	// would treat ("a", "bc") and ("ab", "c") as the same page.
	type pageIdentity struct {
		siteName string
		key      string
	}

	// merge pages; `merged` preserves the id order of the kept rows
	kept := make(map[pageIdentity]*entity.Page, beforeLen)
	merged := make([]*entity.Page, 0, beforeLen)
	for _, page := range pages {
		id := pageIdentity{siteName: page.SiteName, key: page.Key}
		if first, ok := kept[id]; ok {
			first.PV += page.PV
			first.VoteUp += page.VoteUp
			first.VoteDown += page.VoteDown
			continue
		}
		kept[id] = page
		merged = append(merged, page)
	}

	// replace the table contents atomically
	tx := dao.DB().Begin()
	if err := tx.Error; err != nil {
		log.Fatal("Failed to begin transaction. ", err)
		return
	}
	if err := tx.Exec("DELETE FROM pages").Error; err != nil {
		tx.Rollback()
		log.Fatal("Failed to delete pages. ", err)
		return
	}
	if err := tx.CreateInBatches(merged, 1000).Error; err != nil {
		tx.Rollback()
		log.Fatal("Failed to insert merged pages. ", err)
		return
	}
	if err := tx.Commit().Error; err != nil {
		log.Fatal("Failed to commit merged pages. ", err)
		return
	}

	// drop page AccessibleURL column
	if dao.DB().Migrator().HasColumn(&entity.Page{}, "accessible_url") {
		if err := dao.DB().Migrator().DropColumn(&entity.Page{}, "accessible_url"); err != nil {
			// The merge itself already succeeded, so only report the failure.
			log.Error("Failed to drop accessible_url column. ", err)
		}
	}

	log.Info("Pages merged successfully. Before pages: ", beforeLen, ", After pages: ", len(merged), ", Deleted pages: ", beforeLen-len(merged))
	os.Exit(0)
}
83 changes: 61 additions & 22 deletions internal/dao/query_find_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,63 @@ import (
"strings"

"github.com/ArtalkJS/Artalk/internal/entity"
"github.com/ArtalkJS/Artalk/internal/log"
"github.com/ArtalkJS/Artalk/internal/utils"
"golang.org/x/sync/singleflight"
)

func (dao *Dao) FindCreateSite(siteName string) entity.Site {
site := dao.FindSite(siteName)
if site.IsEmpty() {
site = dao.NewSite(siteName, "")
// findCreateSingleFlightGroup is the package-wide singleflight group that
// FindCreateAction runs its find/create closures through, keyed by the
// string passed to FindCreateAction.
var findCreateSingleFlightGroup = new(singleflight.Group)

// EntityHasIsEmpty is the type constraint accepted by FindCreateAction: any
// value that can report, through IsEmpty, whether it holds no entity.
type EntityHasIsEmpty interface{ IsEmpty() bool }

// FindCreateAction (Thread Safe)
//
// Runs findAction and, when the entity it returns is empty, runs
// createAction. The whole find-then-create sequence is executed through a
// singleflight.Group under the given key, to prevent duplicate creation if
// multiple goroutines access at the same time.
//
// On success the found or newly created entity is returned with a nil error.
// When either action fails the error is logged and returned together with
// the zero value of T.
func FindCreateAction[T EntityHasIsEmpty](
	key string,
	findAction func() (T, error),
	createAction func() (T, error),
) (T, error) {
	result, err, _ := findCreateSingleFlightGroup.Do(key, func() (any, error) {
		r, err := findAction()
		if err != nil {
			return nil, err
		}
		if !r.IsEmpty() {
			return r, nil
		}
		if r, err = createAction(); err != nil {
			return nil, err
		}
		return r, nil
	})
	if err != nil {
		log.Error("[FindCreate] ", err)
		// result is a nil interface on failure; asserting it to T would
		// panic, so hand back the zero value instead.
		var zero T
		return zero, err
	}
	return result.(T), nil
}

// FindCreateSite returns the site found for siteName, or the site created by
// NewSite(siteName, "") when the lookup yields an empty entity. The lookup
// and creation run through FindCreateAction under the SiteByNameKey key.
func (dao *Dao) FindCreateSite(siteName string) entity.Site {
	flightKey := fmt.Sprintf(SiteByNameKey, siteName)

	lookup := func() (entity.Site, error) {
		return dao.FindSite(siteName), nil
	}
	create := func() (entity.Site, error) {
		return dao.NewSite(siteName, ""), nil
	}

	// Both closures always return a nil error, so the error result is dropped.
	site, _ := FindCreateAction(flightKey, lookup, create)
	return site
}

func (dao *Dao) FindCreatePage(pageKey string, pageTitle string, siteName string) entity.Page {
page := dao.FindPage(pageKey, siteName)
if page.IsEmpty() {
page = dao.NewPage(pageKey, pageTitle, siteName)
}
return page
r, _ := FindCreateAction(fmt.Sprintf(PageByKeySiteNameKey, pageKey, siteName), func() (entity.Page, error) {
return dao.FindPage(pageKey, siteName), nil
}, func() (entity.Page, error) {
return dao.NewPage(pageKey, pageTitle, siteName), nil
})
return r
}

func (dao *Dao) FindCreateUser(name string, email string, link string) (user entity.User, err error) {
func (dao *Dao) FindCreateUser(name string, email string, link string) (entity.User, error) {
name = strings.TrimSpace(name)
email = strings.TrimSpace(email)
link = strings.TrimSpace(link)
Expand All @@ -37,20 +74,22 @@ func (dao *Dao) FindCreateUser(name string, email string, link string) (user ent
if link != "" && !utils.ValidateURL(link) {
link = ""
}
user = dao.FindUser(name, email)
if user.IsEmpty() {
user, err = dao.NewUser(name, email, link) // save a new user
return FindCreateAction(fmt.Sprintf(UserByNameEmailKey, name, email), func() (entity.User, error) {
return dao.FindUser(name, email), nil
}, func() (entity.User, error) {
user, err := dao.NewUser(name, email, link) // save a new user
if err != nil {
return entity.User{}, err
}
}
return user, nil
return user, nil
})
}

func (dao *Dao) FindCreateNotify(userID uint, lookCommentID uint) entity.Notify {
notify := dao.FindNotify(userID, lookCommentID)
if notify.IsEmpty() {
notify = dao.NewNotify(userID, lookCommentID)
}
return notify
func (dao *Dao) FindCreateNotify(userID uint, commentID uint) entity.Notify {
r, _ := FindCreateAction(fmt.Sprintf(NotifyByUserCommentKey, userID, commentID), func() (entity.Notify, error) {
return dao.FindNotify(userID, commentID), nil
}, func() (entity.Notify, error) {
return dao.NewNotify(userID, commentID), nil
})
return r
}
101 changes: 100 additions & 1 deletion internal/dao/query_find_create_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package dao_test

import (
"fmt"
"sync"
"testing"
"time"

"github.com/ArtalkJS/Artalk/internal/dao"
"github.com/ArtalkJS/Artalk/test"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -55,7 +59,41 @@ func TestFindCreatePage(t *testing.T) {
t.Run("Find Existed Page", func(t *testing.T) {
result := app.Dao().FindCreatePage("/test/1000.html", "", "Site A")
assert.False(t, result.IsEmpty())
assert.Equal(t, app.Dao().FindPage("/test/1000.html", "Site A"), result)
findPage := app.Dao().FindPage("/test/1000.html", "Site A")
assert.Equal(t, app.Dao().CookPage(&findPage), app.Dao().CookPage(&result))
})

t.Run("Concurrent FindCreatePage", func(t *testing.T) {
var (
pageKey = "/" + time.Now().String() + ".html"
pageTitle = "New Page Title " + time.Now().String()
siteName = "Site A"
)

// simulate concurrent requests
var wg sync.WaitGroup

var idMap sync.Map
n := 10000
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer wg.Done()
result := app.Dao().FindCreatePage(pageKey, pageTitle, siteName)
idMap.Store(result.ID, true)
}()
}

wg.Wait()

// count the number of different pages
count := 0
idMap.Range(func(_, _ interface{}) bool {
count++
return true
})

assert.Equal(t, 1, count, fmt.Sprintf("Concurrent FindCreatePage should return the same page, but got %d different pages", count))
})
}

Expand Down Expand Up @@ -113,3 +151,64 @@ func TestFindCreateUser(t *testing.T) {
assert.Equal(t, app.Dao().FindUser("userA", "[email protected]"), result)
})
}

// mockEntity is a minimal entity used to exercise FindCreateAction without
// touching the database.
type mockEntity struct {
	ID   int
	Name string
}

// IsEmpty reports whether the entity has no ID assigned yet; Name is not
// taken into account.
func (m mockEntity) IsEmpty() bool {
	const unassignedID = 0
	return m.ID == unassignedID
}

// TestFindCreateAction checks that many goroutines calling FindCreateAction
// with the same key at the same time all receive a non-empty entity while
// createAction runs exactly once.
func TestFindCreateAction(t *testing.T) {
	app, _ := test.NewTestApp()
	defer app.Cleanup()

	var calledTimes int32
	var mutex sync.Mutex

	increaseCalledTimes := func() {
		mutex.Lock()
		defer mutex.Unlock()
		calledTimes++
	}

	t.Run("Concurrent FindCreateAction", func(t *testing.T) {
		var wg sync.WaitGroup

		// The key is generated once so every goroutine contends on the same
		// singleflight entry. Building it per call would give each goroutine
		// its own key, so nothing would be coalesced and the closures below
		// would race on `instance`.
		randKey := fmt.Sprintf("rand_key_%d", time.Now().UnixNano())

		var instance mockEntity
		findCreateFunc := func() (mockEntity, error) {
			return dao.FindCreateAction(randKey, func() (mockEntity, error) {
				// findAction
				time.Sleep(200 * time.Millisecond) // mock time consuming
				return instance, nil
			}, func() (mockEntity, error) {
				// createAction
				instance = mockEntity{
					ID:   1,
					Name: "mockEntity",
				}
				time.Sleep(500 * time.Millisecond) // mock time consuming
				increaseCalledTimes()
				return instance, nil
			})
		}

		n := 10000
		for i := 0; i < n; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				result, err := findCreateFunc()
				assert.NoError(t, err)
				assert.False(t, result.IsEmpty(), "FindCreateAction should always return a non-empty entity")
			}()
		}

		wg.Wait()

		assert.Equal(t, int32(1), calledTimes, "Concurrent FindCreateAction should only call createAction once")
	})
}

0 comments on commit a9331b4

Please sign in to comment.