Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dao): ensure find and create functions thread safe #845

Merged
merged 2 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/dao/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const (
PageByKeySiteNameKey = "page#key=%s;site_name=%s"
CommentByIDKey = "comment#id=%d"
CommentChildIDsByIDKey = "comment_child_ids#id=%d"
NotifyByUserCommentKey = "notify#user_id=%d;comment_id=%d"
)

type DaoCache struct {
Expand Down
51 changes: 51 additions & 0 deletions internal/dao/migrate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dao

import (
"os"

"github.com/ArtalkJS/Artalk/internal/entity"
"github.com/ArtalkJS/Artalk/internal/log"
)
Expand All @@ -21,6 +23,11 @@ func (dao *Dao) MigrateModels() {
// because there are many different DBs and the implementation of foreign keys may be different,
// and the DB may not support foreign keys, so don't rely on the foreign key function of the DB system.
dao.DropConstraintsIfExist()

// Merge pages
if os.Getenv("ATK_DB_MIGRATOR_FUNC_MERGE_PAGES") == "1" {
dao.MergePages()
}
}

// Remove all constraints
Expand Down Expand Up @@ -87,3 +94,47 @@ func (dao *Dao) MigrateRootID() {

log.Info(TAG, "Root IDs generated successfully.")
}

func (dao *Dao) MergePages() {
// merge pages with same key and site_name, sum pv
pages := []*entity.Page{}

// load all pages
if err := dao.DB().Order("id ASC").Find(&pages).Error; err != nil {
log.Fatal("Failed to load pages. ", err.Error)
}
beforeLen := len(pages)

// merge pages
mergedPages := map[string]*entity.Page{}
for _, page := range pages {
key := page.SiteName + page.Key
if _, ok := mergedPages[key]; !ok {
mergedPages[key] = page
} else {
mergedPages[key].PV += page.PV
mergedPages[key].VoteUp += page.VoteUp
mergedPages[key].VoteDown += page.VoteDown
}
}

// delete all pages
dao.DB().Exec("DELETE FROM pages")

// insert merged pages
pages = []*entity.Page{}
for _, page := range mergedPages {
pages = append(pages, page)
}
if err := dao.DB().CreateInBatches(pages, 1000); err.Error != nil {
log.Fatal("Failed to insert merged pages. ", err.Error)
}

// drop page AccessibleURL column
if dao.DB().Migrator().HasColumn(&entity.Page{}, "accessible_url") {
dao.DB().Migrator().DropColumn(&entity.Page{}, "accessible_url")
}

log.Info("Pages merged successfully. Before pages: ", beforeLen, ", After pages: ", len(mergedPages), ", Deleted pages: ", beforeLen-len(mergedPages))
os.Exit(0)
}
83 changes: 61 additions & 22 deletions internal/dao/query_find_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,63 @@ import (
"strings"

"github.com/ArtalkJS/Artalk/internal/entity"
"github.com/ArtalkJS/Artalk/internal/log"
"github.com/ArtalkJS/Artalk/internal/utils"
"golang.org/x/sync/singleflight"
)

func (dao *Dao) FindCreateSite(siteName string) entity.Site {
site := dao.FindSite(siteName)
if site.IsEmpty() {
site = dao.NewSite(siteName, "")
var findCreateSingleFlightGroup = new(singleflight.Group)

type EntityHasIsEmpty interface {
IsEmpty() bool
}

// FindCreateAction (Thread Safe)
//
// Use singleflight.Group to prevent duplicate creation if multiple goroutines access at the same time.
func FindCreateAction[T EntityHasIsEmpty](
key string,
findAction func() (T, error),
createAction func() (T, error),
) (T, error) {
result, err, _ := findCreateSingleFlightGroup.Do(key, func() (any, error) {
r, err := findAction()
if err != nil {
return nil, err
}
if r.IsEmpty() {
if r, err = createAction(); err != nil {
return nil, err
}
}
return r, nil
})
if err != nil {
log.Error("[FindCreate] ", err)
return result.(T), err
}
return site
return result.(T), nil
}

func (dao *Dao) FindCreateSite(siteName string) entity.Site {
r, _ := FindCreateAction(fmt.Sprintf(SiteByNameKey, siteName), func() (entity.Site, error) {
return dao.FindSite(siteName), nil
}, func() (entity.Site, error) {
return dao.NewSite(siteName, ""), nil
})
return r
}

func (dao *Dao) FindCreatePage(pageKey string, pageTitle string, siteName string) entity.Page {
page := dao.FindPage(pageKey, siteName)
if page.IsEmpty() {
page = dao.NewPage(pageKey, pageTitle, siteName)
}
return page
r, _ := FindCreateAction(fmt.Sprintf(PageByKeySiteNameKey, pageKey, siteName), func() (entity.Page, error) {
return dao.FindPage(pageKey, siteName), nil
}, func() (entity.Page, error) {
return dao.NewPage(pageKey, pageTitle, siteName), nil
})
return r
}

func (dao *Dao) FindCreateUser(name string, email string, link string) (user entity.User, err error) {
func (dao *Dao) FindCreateUser(name string, email string, link string) (entity.User, error) {
name = strings.TrimSpace(name)
email = strings.TrimSpace(email)
link = strings.TrimSpace(link)
Expand All @@ -37,20 +74,22 @@ func (dao *Dao) FindCreateUser(name string, email string, link string) (user ent
if link != "" && !utils.ValidateURL(link) {
link = ""
}
user = dao.FindUser(name, email)
if user.IsEmpty() {
user, err = dao.NewUser(name, email, link) // save a new user
return FindCreateAction(fmt.Sprintf(UserByNameEmailKey, name, email), func() (entity.User, error) {
return dao.FindUser(name, email), nil
}, func() (entity.User, error) {
user, err := dao.NewUser(name, email, link) // save a new user
if err != nil {
return entity.User{}, err
}
}
return user, nil
return user, nil
})
}

func (dao *Dao) FindCreateNotify(userID uint, lookCommentID uint) entity.Notify {
notify := dao.FindNotify(userID, lookCommentID)
if notify.IsEmpty() {
notify = dao.NewNotify(userID, lookCommentID)
}
return notify
func (dao *Dao) FindCreateNotify(userID uint, commentID uint) entity.Notify {
r, _ := FindCreateAction(fmt.Sprintf(NotifyByUserCommentKey, userID, commentID), func() (entity.Notify, error) {
return dao.FindNotify(userID, commentID), nil
}, func() (entity.Notify, error) {
return dao.NewNotify(userID, commentID), nil
})
return r
}
101 changes: 100 additions & 1 deletion internal/dao/query_find_create_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package dao_test

import (
"fmt"
"sync"
"testing"
"time"

"github.com/ArtalkJS/Artalk/internal/dao"
"github.com/ArtalkJS/Artalk/test"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -55,7 +59,41 @@ func TestFindCreatePage(t *testing.T) {
t.Run("Find Existed Page", func(t *testing.T) {
result := app.Dao().FindCreatePage("/test/1000.html", "", "Site A")
assert.False(t, result.IsEmpty())
assert.Equal(t, app.Dao().FindPage("/test/1000.html", "Site A"), result)
findPage := app.Dao().FindPage("/test/1000.html", "Site A")
assert.Equal(t, app.Dao().CookPage(&findPage), app.Dao().CookPage(&result))
})

t.Run("Concurrent FindCreatePage", func(t *testing.T) {
var (
pageKey = "/" + time.Now().String() + ".html"
pageTitle = "New Page Title " + time.Now().String()
siteName = "Site A"
)

// simulate concurrent requests
var wg sync.WaitGroup

var idMap sync.Map
n := 10000
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer wg.Done()
result := app.Dao().FindCreatePage(pageKey, pageTitle, siteName)
idMap.Store(result.ID, true)
}()
}

wg.Wait()

// count the number of different pages
count := 0
idMap.Range(func(_, _ interface{}) bool {
count++
return true
})

assert.Equal(t, 1, count, fmt.Sprintf("Concurrent FindCreatePage should return the same page, but got %d different pages", count))
})
}

Expand Down Expand Up @@ -113,3 +151,64 @@ func TestFindCreateUser(t *testing.T) {
assert.Equal(t, app.Dao().FindUser("userA", "[email protected]"), result)
})
}

type mockEntity struct {
ID int
Name string
}

func (e mockEntity) IsEmpty() bool {
return e.ID == 0
}

func TestFindCreateAction(t *testing.T) {
app, _ := test.NewTestApp()
defer app.Cleanup()

var calledTimes int32
var mutex sync.Mutex

increaseCalledTimes := func() {
mutex.Lock()
defer mutex.Unlock()
calledTimes++
}

t.Run("Concurrent FindCreateAction", func(t *testing.T) {
var wg sync.WaitGroup

var instance mockEntity
findCreateFunc := func() (mockEntity, error) {
randKey := fmt.Sprintf("rand_key_%d", time.Now().UnixNano())
return dao.FindCreateAction(randKey, func() (mockEntity, error) {
// findAction
time.Sleep(200 * time.Millisecond) // mock time consuming
return instance, nil
}, func() (mockEntity, error) {
// createAction
instance = mockEntity{
ID: 1,
Name: "mockEntity",
}
time.Sleep(500 * time.Millisecond) // mock time consuming
increaseCalledTimes()
return instance, nil
})
}

n := 10000
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer wg.Done()
result, err := findCreateFunc()
assert.NoError(t, err)
assert.False(t, result.IsEmpty(), "FindCreateAction should always return a non-empty entity")
}()
}

wg.Wait()

assert.Equal(t, int32(1), calledTimes, "Concurrent FindCreateAction should only call createAction once")
})
}