From 11fca8bd8d58fd12139f6dd06febd1198cdecba7 Mon Sep 17 00:00:00 2001 From: Manuel Doncel Martos Date: Tue, 12 Nov 2024 01:07:51 +0100 Subject: [PATCH] using clause approach --- internal/model.go | 32 --------------- model.go | 84 +++++++++++++++++++++++++++++---------- model_test.go | 71 +++++++++++++++++++++++++++++++++ pagorminator.go | 51 +++++------------------- pagorminator_test.go | 94 ++++++++++++++++++++------------------------ 5 files changed, 188 insertions(+), 144 deletions(-) delete mode 100644 internal/model.go create mode 100644 model_test.go diff --git a/internal/model.go b/internal/model.go deleted file mode 100644 index bab061a..0000000 --- a/internal/model.go +++ /dev/null @@ -1,32 +0,0 @@ -package internal - -type PageRequestImpl struct { - Page int - Size int - TotalPages int - TotalElements int -} - -func (p PageRequestImpl) GetOffset() int { - return (p.Page - 1) * p.Size -} - -func (p PageRequestImpl) GetPage() int { - return p.Page -} - -func (p PageRequestImpl) GetSize() int { - return p.Size -} - -func (p PageRequestImpl) GetTotalPages() int { - return p.TotalPages -} - -func (p PageRequestImpl) GetTotalElements() int { - return p.TotalElements -} - -func (p PageRequestImpl) IsUnPaged() bool { - return p.Size == 0 && p.Page == 0 -} diff --git a/model.go b/model.go index ee438cb..01741b3 100644 --- a/model.go +++ b/model.go @@ -2,42 +2,86 @@ package pagorminator import ( "errors" - "github.com/manuelarte/pagorminator/internal" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "math" ) +const pagorminatorClause = "pagorminator:clause" + var ( ErrPageCantBeNegative = errors.New("page number can't be negative") ErrSizeCantBeNegative = errors.New("size can't be negative") ErrSizeNotAllowed = errors.New("size is not allowed") ) -var _ PageRequest = internal.PageRequestImpl{} - -// PageRequest Struct that contains the pagination information -type PageRequest interface { - GetPage() int - GetSize() int - GetOffset() int - GetTotalPages() int - GetTotalElements() int - IsUnPaged() bool -} +var _ clause.Expression = new(Pagination) +var _ gorm.StatementModifier = new(Pagination) -// PageRequestOf Creates a PageRequest with the page and size values -func PageRequestOf(page, size int) (PageRequest, error) { +// PageRequest Create page to query the database +func PageRequest(page, size int) (Pagination, error) { if page < 0 { - return nil, ErrPageCantBeNegative + return Pagination{}, ErrPageCantBeNegative } if size < 0 { - return nil, ErrSizeCantBeNegative + return Pagination{}, ErrSizeCantBeNegative } if page > 0 && size == 0 { - return nil, ErrSizeNotAllowed + return Pagination{}, ErrSizeNotAllowed } - return &internal.PageRequestImpl{Page: page, Size: size}, nil + return Pagination{page: page, size: size}, nil } // UnPaged Create an unpaged request (no pagination is applied) -func UnPaged() PageRequest { - return &internal.PageRequestImpl{Page: 0, Size: 0} +func UnPaged() Pagination { + return Pagination{page: 0, size: 0} +} + +// Pagination Clause to apply pagination +type Pagination struct { + page int + size int + totalElements int64 +} + +func (p *Pagination) GetPage() int { + return p.page +} +func (p *Pagination) GetSize() int { + return p.size +} + +func (p *Pagination) GetOffset() int { + return (p.page - 1) * p.size +} + +func (p *Pagination) GetTotalPages() int { + if p.size > 0 { + return calculateTotalPages(p.totalElements, p.size) + } else { + return 1 + } +} + +func (p *Pagination) GetTotalElements() int64 { + return p.totalElements +} + +func (p *Pagination) IsUnPaged() bool { + return p.page == 0 && p.size == 0 +} + +func (p *Pagination) ModifyStatement(stm *gorm.Statement) { + db := stm.DB + db.Set(pagorminatorClause, p) + if !p.IsUnPaged() { + stm.DB.Limit(p.size).Offset((p.page - 1) * p.size) + } +} + +func (p *Pagination) Build(_ clause.Builder) { +} + +func calculateTotalPages(totalElements int64, size int) int { + return int(math.Ceil(float64(totalElements) / float64(size))) } diff --git a/model_test.go b/model_test.go new file mode 100644 index 0000000..95df7d5 --- /dev/null +++ b/model_test.go @@ -0,0 +1,71 @@ +package pagorminator + +import "testing" + +func TestPagination_UnPaged(t *testing.T) { + t.Parallel() + tests := map[string]struct { + page int + size int + expected bool + }{ + "page 0 size 0": { + page: 0, + size: 0, + expected: true, + }, + "page zero size not zero": { + page: 0, + size: 1, + expected: false, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + page, err := PageRequest(test.page, test.size) + if err != nil { + t.Errorf("Unexpected error: %s", err) + } + if page.IsUnPaged() != test.expected { + t.Errorf("IsUnPaged() expected %v, got %v", test.expected, page.IsUnPaged()) + } + }) + } +} + +func TestPagination_CalculateTotalPages(t *testing.T) { + t.Parallel() + tests := map[string]struct { + totalElements int64 + size int + expected int + }{ + "totalElements lower than size": { + totalElements: 2, + size: 4, + expected: 1, + }, + "totalElements greater and not divisible by size": { + totalElements: 3, + size: 2, + expected: 2, + }, + "totalElements greater and divisible by size": { + totalElements: 4, + size: 2, + expected: 2, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + actual := calculateTotalPages(test.totalElements, test.size) + if actual != test.expected { + t.Errorf("totalPages expected %v, got %v", test.expected, actual) + } + }) + } +} diff --git a/pagorminator.go b/pagorminator.go index 1a306ca..e4903db 100644 --- a/pagorminator.go +++ b/pagorminator.go @@ -1,7 +1,6 @@ package pagorminator import ( - "github.com/manuelarte/pagorminator/internal" "gorm.io/gorm" ) @@ -9,15 +8,9 @@ const ( countKey = "pagorminator.count" ) -func WithPagination(pageRequest PageRequest) func(*gorm.DB) *gorm.DB { - return func(db *gorm.DB) *gorm.DB { - return db.Set("pagorminator:pageRequest", pageRequest) - } -} - var _ gorm.Plugin = new(PaGormMinator) -// PaGormMinator Gorm plugin to add pagination to your queries +// PaGormMinator Gorm plugin to add total elements and total pages to your pagination query type PaGormMinator struct { } @@ -26,54 +19,30 @@ func (p PaGormMinator) Name() string { } func (p PaGormMinator) Initialize(db *gorm.DB) error { - err := db.Callback().Query().Before("gorm:query").Register("pagorminator:addPagination", p.addPagination) - if err != nil { - return err - } - err = db.Callback().Query().After("pagorminator:addPagination").Register("pagorminator:count", p.count) + err := db.Callback().Query().Before("gorm:query").Register("pagorminator:count", p.count) if err != nil { return err } return nil } -func (p PaGormMinator) addPagination(db *gorm.DB) { - if db.Statement.Schema != nil { - if pageRequest, ok := p.getPageRequest(db); ok { - if !pageRequest.IsUnPaged() { - db.Limit(pageRequest.GetSize()).Offset(pageRequest.GetOffset()) - } - } - - } -} - func (p PaGormMinator) count(db *gorm.DB) { if db.Statement.Schema != nil { - if pageRequest, ok := p.getPageRequest(db); ok { + if pageable, ok := p.getPageRequest(db); ok { if value, ok := db.Get(countKey); !ok || !value.(bool) { - casted, _ := pageRequest.(*internal.PageRequestImpl) - newDb := db.Session(&gorm.Session{NewDB: true}) newDb.Statement = db.Statement.Statement var totalElements int64 - tx := newDb.Debug().Set(countKey, true). - Model(newDb.Statement.Model) + tx := newDb.Set(countKey, true).Model(newDb.Statement.Model) if whereClause, existWhere := db.Statement.Clauses["WHERE"]; existWhere { tx.Where(whereClause.Expression) } tx.Count(&totalElements) if tx.Error != nil { - db.AddError(tx.Error) + _ = db.AddError(tx.Error) } else { - casted.TotalElements = int(totalElements) - if casted.IsUnPaged() { - casted.Page = 0 - casted.TotalPages = 1 - } else { - casted.TotalPages = int(totalElements) / casted.Size - } + pageable.totalElements = totalElements } } } @@ -81,10 +50,10 @@ func (p PaGormMinator) count(db *gorm.DB) { } } -func (p PaGormMinator) getPageRequest(db *gorm.DB) (PageRequest, bool) { - if value, ok := db.Get("pagorminator:pageRequest"); ok { - if pageRequest, ok := value.(PageRequest); ok { - return pageRequest, true +func (p PaGormMinator) getPageRequest(db *gorm.DB) (*Pagination, bool) { + if value, ok := db.Get(pagorminatorClause); ok { + if paginationClause, ok := value.(*Pagination); ok { + return paginationClause, true } } return nil, false diff --git a/pagorminator_test.go b/pagorminator_test.go index 1959b37..ccc4428 100644 --- a/pagorminator_test.go +++ b/pagorminator_test.go @@ -2,7 +2,6 @@ package pagorminator import ( "fmt" - "github.com/manuelarte/pagorminator/internal" "gorm.io/driver/sqlite" "gorm.io/gorm" "testing" @@ -15,21 +14,21 @@ type TestStruct struct { } func TestPaginationScopeMetadata_NoWhere(t *testing.T) { + t.Parallel() tests := map[string]struct { toMigrate []*TestStruct - pageRequest PageRequest - expectedPage PageRequest + pageRequest Pagination + expectedPage Pagination }{ "UnPaged one item": { toMigrate: []*TestStruct{ {Code: "1"}, }, pageRequest: UnPaged(), - expectedPage: &internal.PageRequestImpl{ - Page: 0, - Size: 0, - TotalElements: 1, - TotalPages: 1, + expectedPage: Pagination{ + page: 0, + size: 0, + totalElements: 1, }, }, "UnPaged several items": { @@ -37,11 +36,10 @@ func TestPaginationScopeMetadata_NoWhere(t *testing.T) { {Code: "1", Price: 1}, {Code: "2", Price: 2}, }, pageRequest: UnPaged(), - expectedPage: &internal.PageRequestImpl{ - Page: 0, - Size: 0, - TotalElements: 2, - TotalPages: 1, + expectedPage: Pagination{ + page: 0, + size: 0, + totalElements: 2, }, }, "Paged 1/2 items": { @@ -49,11 +47,10 @@ func TestPaginationScopeMetadata_NoWhere(t *testing.T) { {Code: "1", Price: 1}, {Code: "2", Price: 2}, }, pageRequest: mustPageRequestOf(1, 1), - expectedPage: &internal.PageRequestImpl{ - Page: 1, - Size: 1, - TotalElements: 2, - TotalPages: 2, + expectedPage: Pagination{ + page: 1, + size: 1, + totalElements: 2, }, }, } @@ -66,7 +63,7 @@ func TestPaginationScopeMetadata_NoWhere(t *testing.T) { // Read var products []*TestStruct - db.Scopes(WithPagination(test.pageRequest)).Find(&products) // find product with integer primary key + db.Clauses(&test.pageRequest).Find(&products) // find product with integer primary key if !equalPageRequests(test.pageRequest, test.expectedPage) { t.Fatalf("expected page to be %d, got %d", test.expectedPage, test.pageRequest) } @@ -75,11 +72,12 @@ func TestPaginationScopeMetadata_NoWhere(t *testing.T) { } func TestPaginationScopeMetadata_Where(t *testing.T) { + t.Parallel() tests := map[string]struct { toMigrate []*TestStruct - pageRequest PageRequest + pageRequest Pagination where string - expectedPage PageRequest + expectedPage Pagination }{ "UnPaged one item, not filtered": { toMigrate: []*TestStruct{ @@ -87,11 +85,10 @@ func TestPaginationScopeMetadata_Where(t *testing.T) { }, pageRequest: UnPaged(), where: "price < 100", - expectedPage: &internal.PageRequestImpl{ - Page: 0, - Size: 0, - TotalElements: 1, - TotalPages: 1, + expectedPage: Pagination{ + page: 0, + size: 0, + totalElements: 1, }, }, "UnPaged one item, filtered out": { @@ -100,11 +97,10 @@ func TestPaginationScopeMetadata_Where(t *testing.T) { }, pageRequest: UnPaged(), where: "price > 100", - expectedPage: &internal.PageRequestImpl{ - Page: 0, - Size: 0, - TotalElements: 0, - TotalPages: 1, + expectedPage: Pagination{ + page: 0, + size: 0, + totalElements: 0, }, }, "UnPaged two items, one filtered out": { @@ -113,11 +109,10 @@ func TestPaginationScopeMetadata_Where(t *testing.T) { }, pageRequest: UnPaged(), where: "price > 50", - expectedPage: &internal.PageRequestImpl{ - Page: 0, - Size: 0, - TotalElements: 1, - TotalPages: 1, + expectedPage: Pagination{ + page: 0, + size: 0, + totalElements: 1, }, }, "Paged four items, two filtered out": { @@ -127,11 +122,10 @@ func TestPaginationScopeMetadata_Where(t *testing.T) { }, pageRequest: mustPageRequestOf(0, 1), where: "price > 50", - expectedPage: &internal.PageRequestImpl{ - Page: 0, - Size: 1, - TotalElements: 2, - TotalPages: 2, + expectedPage: Pagination{ + page: 0, + size: 1, + totalElements: 2, }, }, } @@ -144,7 +138,7 @@ func TestPaginationScopeMetadata_Where(t *testing.T) { // Read var products []*TestStruct - db.Debug().Scopes(WithPagination(test.pageRequest)).Where(test.where).Find(&products) // find product with integer primary key + db.Clauses(&test.pageRequest).Where(test.where).Find(&products) if !equalPageRequests(test.pageRequest, test.expectedPage) { t.Fatalf("expected page to be %d, got %d", test.expectedPage, test.pageRequest) } @@ -171,16 +165,14 @@ func setupDb(t *testing.T, name string) *gorm.DB { return db } -func mustPageRequestOf(page, size int) PageRequest { - toReturn, _ := PageRequestOf(page, size) +func mustPageRequestOf(page, size int) Pagination { + toReturn, _ := PageRequest(page, size) return toReturn } -func equalPageRequests(p1, p2 PageRequest) bool { - casted1 := p1.(*internal.PageRequestImpl) - casted2 := p2.(*internal.PageRequestImpl) - return casted1.Page == casted2.Page && - casted1.Size == casted2.Size && - casted1.TotalElements == casted2.TotalElements && - casted1.TotalPages == casted2.TotalPages +func equalPageRequests(p1, p2 Pagination) bool { + return p1.page == p2.page && + p1.size == p2.size && + p1.totalElements == p2.totalElements && + p1.GetTotalPages() == p2.GetTotalPages() }