Skip to content

Commit

Permalink
fix(db): sql error with custom table prefix (#970)
Browse files Browse the repository at this point in the history
  • Loading branch information
qwqcode authored Sep 4, 2024
1 parent 90d8503 commit 4eb3dba
Show file tree
Hide file tree
Showing 15 changed files with 71 additions and 21 deletions.
2 changes: 1 addition & 1 deletion internal/artransfer/importer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func Test_importArtrans(t *testing.T) {

// Mock db error
// make comments table unique on content
dao.DB().Exec("CREATE UNIQUE INDEX idx_comments_content ON comments (content)")
dao.DB().Exec("CREATE UNIQUE INDEX idx_comments_content ON " + dao.GetTableName(&entity.Comment{}) + " (content)")

err := RunImportArtrans(dao, &params)
assert.Error(t, err, "Import should be failed")
Expand Down
13 changes: 7 additions & 6 deletions internal/dao/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,26 @@ func (dao *Dao) MigrateRootID() {
dao.DB().Migrator().AddColumn(&entity.Comment{}, "root_id")
}

tbComments := dao.GetTableName(&entity.Comment{})
if err := dao.DB().Raw(`WITH RECURSIVE CommentHierarchy AS (
SELECT id, id AS root_id, rid
FROM comments
FROM ` + tbComments + `
WHERE rid = 0
UNION ALL
SELECT c.id, ch.root_id, c.rid
FROM comments c
FROM ` + tbComments + ` c
INNER JOIN CommentHierarchy ch ON c.rid = ch.id
)
UPDATE comments SET root_id = (
UPDATE ` + tbComments + ` SET root_id = (
SELECT root_id
FROM CommentHierarchy
WHERE comments.id = CommentHierarchy.id
WHERE ` + tbComments + `.id = CommentHierarchy.id
);
`).Scan(&struct{}{}).Error; err == nil {
// no error, then do some patch
dao.DB().Table("comments").Where("id = root_id").Update("root_id", 0)
dao.DB().Model(&entity.Comment{}).Where("id = root_id").Update("root_id", 0)
} else {
// try backup plan (if recursive CTE is not supported)
log.Info(TAG, "Recursive CTE is not supported, trying backup plan... Please wait a moment. This may take a long time if there are many comments.")
Expand Down Expand Up @@ -133,7 +134,7 @@ func (dao *Dao) MergePages() {
}

// delete all pages
dao.DB().Exec("DELETE FROM pages")
dao.DB().Where("1 = 1").Delete(&entity.Page{})

// insert merged pages
pages = []*entity.Page{}
Expand Down
13 changes: 12 additions & 1 deletion internal/dao/query.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
package dao

import "github.com/ArtalkJS/Artalk/internal/entity"
import (
"github.com/ArtalkJS/Artalk/internal/entity"
"gorm.io/gorm"
)

func (dao *Dao) GetUserAllCommentIDs(userID uint) []uint {
userAllCommentIDs := []uint{}
dao.DB().Model(&entity.Comment{}).Select("id").Where("user_id = ?", userID).Find(&userAllCommentIDs)
return userAllCommentIDs
}

// Get the table name of the entity
func (dao *Dao) GetTableName(entity any) string {
// @see https://github.com/go-gorm/gorm/issues/3603#issuecomment-709883403
stmt := &gorm.Statement{DB: dao.DB()}
stmt.Parse(entity)
return stmt.Schema.Table
}
17 changes: 17 additions & 0 deletions internal/dao/query_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package dao_test

import (
"testing"

"github.com/ArtalkJS/Artalk/internal/entity"
"github.com/ArtalkJS/Artalk/test"
"github.com/stretchr/testify/assert"
)

func TestGetTableName(t *testing.T) {
app, _ := test.NewTestApp()
defer app.Cleanup()

assert.Equal(t, "atk_pages", app.Dao().GetTableName(&entity.Page{}))
assert.Equal(t, "atk_comments", app.Dao().GetTableName(&entity.Comment{}))
}
7 changes: 6 additions & 1 deletion internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ func NewDB(conf config.DBConf) (*gorm.DB, error) {
}

func NewTestDB() (*gorm.DB, error) {
return OpenSQLite("file::memory:?cache=shared", &gorm.Config{DisableForeignKeyConstraintWhenMigrating: true})
return OpenSQLite("file::memory:?cache=shared", &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
NamingStrategy: schema.NamingStrategy{
TablePrefix: "atk_",
},
})
}

func CloseDB(db *gorm.DB) error {
Expand Down
2 changes: 1 addition & 1 deletion internal/entity/page.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type Page struct {
//
// And must caution that the db query statement quoted the field name `key` with backticks.
// Different db may have different rules. The pgsql is not backticks, but double quotes.
// So use the pages.key (without any quotes) to instead of `key`.
// So use the pages.key (without any quotes) to instead of `key` (Mind the prefix table name).
//
// Consider to rename this column and make a db migration in the future.
Key string `gorm:"index;size:255"` // Page key
Expand Down
14 changes: 10 additions & 4 deletions server/handler/page_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,16 @@ func PageList(app *core.App, router fiber.Router) {
}

// Search
q = q.Scopes(func(d *gorm.DB) *gorm.DB {
return d.Where("LOWER(pages.key) LIKE LOWER(?) OR LOWER(title) LIKE LOWER(?)",
"%"+p.Search+"%", "%"+p.Search+"%")
})
if p.Search != "" {
q = q.Scopes(func(d *gorm.DB) *gorm.DB {
// Because historical reasons, the naming of this field named `key` does not follow best practices.
// In some database, directly use the field name `key` will cause an error.
// So must keep the table name before the field name.
tbPages := app.Dao().GetTableName(&entity.Page{})
return d.Where("LOWER("+tbPages+".key) LIKE LOWER(?) OR LOWER(title) LIKE LOWER(?)",
"%"+p.Search+"%", "%"+p.Search+"%")
})
}

// Total count
var total int64
Expand Down
8 changes: 6 additions & 2 deletions server/handler/stat.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,13 @@ func Stat(app *core.App, router fiber.Router) {
// ------------------------------------
// Comment most pages
// ------------------------------------
tbPages := app.Dao().GetTableName(&entity.Page{})
tbComments := app.Dao().GetTableName(&entity.Comment{})

var pages []entity.Page
app.Dao().DB().Raw(
"SELECT * FROM pages p WHERE p.site_name = ? ORDER BY (SELECT COUNT(*) FROM comments c WHERE c.page_key = p.key AND c.is_pending = ?) DESC LIMIT ?",
"SELECT * FROM "+tbPages+" p WHERE p.site_name = ? ORDER BY ("+
"SELECT COUNT(*) FROM "+tbComments+" c WHERE c.page_key = p.key AND c.is_pending = ?) DESC LIMIT ?",
p.SiteName, false, p.Limit,
).Find(&pages)

Expand Down Expand Up @@ -152,7 +156,7 @@ func Stat(app *core.App, router fiber.Router) {
// Query Site total PV
// ------------------------------------
var pv int64
app.Dao().DB().Raw("SELECT SUM(pv) FROM pages WHERE site_name = ?", p.SiteName).Row().Scan(&pv)
app.Dao().DB().Model(&entity.Page{}).Where(&entity.Page{SiteName: p.SiteName}).Select("SUM(pv)").Scan(&pv)

return common.RespData(c, ResponseStat{
Data: pv,
Expand Down
10 changes: 6 additions & 4 deletions server/handler/user_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ func UserList(app *core.App, router fiber.Router) {
}

// Search
q = q.Scopes(func(d *gorm.DB) *gorm.DB {
return d.Where("LOWER(name) LIKE LOWER(?) OR LOWER(email) LIKE LOWER(?) OR badge_name = ? OR last_ip = ?",
"%"+p.Search+"%", "%"+p.Search+"%", p.Search, p.Search)
})
if p.Search != "" {
q = q.Scopes(func(d *gorm.DB) *gorm.DB {
return d.Where("LOWER(name) LIKE LOWER(?) OR LOWER(email) LIKE LOWER(?) OR badge_name = ? OR last_ip = ?",
"%"+p.Search+"%", "%"+p.Search+"%", p.Search, p.Search)
})
}

// Total count
var total int64
Expand Down
6 changes: 5 additions & 1 deletion test/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/go-testfixtures/testfixtures/v3"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)

type TestApp struct {
Expand Down Expand Up @@ -48,7 +49,10 @@ func NewTestApp() (*TestApp, error) {

// open a sqlite db
dbInstance, err := gorm.Open(sqlite.Open(dbFile), &gorm.Config{
Logger: db_logger.New(),
Logger: db_logger.New(),
NamingStrategy: schema.NamingStrategy{
TablePrefix: "atk_", // Test table prefix, fixture filenames should match this
},
DisableForeignKeyConstraintWhenMigrating: true,
})
if err != nil {
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 4eb3dba

Please sign in to comment.