Skip to content

Commit

Permalink
using clause approach
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelarte committed Nov 12, 2024
1 parent 3f53751 commit 11fca8b
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 144 deletions.
32 changes: 0 additions & 32 deletions internal/model.go

This file was deleted.

84 changes: 64 additions & 20 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,86 @@ package pagorminator

import (
"errors"
"github.com/manuelarte/pagorminator/internal"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"math"
)

const pagorminatorClause = "pagorminator:clause"

var (
ErrPageCantBeNegative = errors.New("page number can't be negative")
ErrSizeCantBeNegative = errors.New("size can't be negative")
ErrSizeNotAllowed = errors.New("size is not allowed")
)

var _ PageRequest = internal.PageRequestImpl{}

// PageRequest Struct that contains the pagination information
type PageRequest interface {
GetPage() int
GetSize() int
GetOffset() int
GetTotalPages() int
GetTotalElements() int
IsUnPaged() bool
}
var _ clause.Expression = new(Pagination)
var _ gorm.StatementModifier = new(Pagination)

// PageRequestOf Creates a PageRequest with the page and size values
func PageRequestOf(page, size int) (PageRequest, error) {
// PageRequest Create page to query the database
func PageRequest(page, size int) (Pagination, error) {
if page < 0 {
return nil, ErrPageCantBeNegative
return Pagination{}, ErrPageCantBeNegative
}
if size < 0 {
return nil, ErrSizeCantBeNegative
return Pagination{}, ErrSizeCantBeNegative
}
if page > 0 && size == 0 {
return nil, ErrSizeNotAllowed
return Pagination{}, ErrSizeNotAllowed
}
return &internal.PageRequestImpl{Page: page, Size: size}, nil
return Pagination{page: page, size: size}, nil
}

// UnPaged Create an unpaged request (no pagination is applied)
func UnPaged() PageRequest {
return &internal.PageRequestImpl{Page: 0, Size: 0}
func UnPaged() Pagination {
return Pagination{page: 0, size: 0}
}

// Pagination Clause to apply pagination
type Pagination struct {
page int
size int
totalElements int64
}

func (p *Pagination) GetPage() int {
return p.page
}
func (p *Pagination) GetSize() int {
return p.size
}

func (p *Pagination) GetOffset() int {
return (p.page - 1) * p.size
}

func (p *Pagination) GetTotalPages() int {
if p.size > 0 {
return calculateTotalPages(p.totalElements, p.size)
} else {
return 1
}
}

func (p *Pagination) GetTotalElements() int64 {
return p.totalElements
}

func (p *Pagination) IsUnPaged() bool {
return p.page == 0 && p.size == 0
}

func (p *Pagination) ModifyStatement(stm *gorm.Statement) {
db := stm.DB
db.Set(pagorminatorClause, p)
if !p.IsUnPaged() {
stm.DB.Limit(p.size).Offset((p.page - 1) * p.size)
}
}

func (p *Pagination) Build(_ clause.Builder) {
}

func calculateTotalPages(totalElements int64, size int) int {
return int(math.Ceil(float64(totalElements) / float64(size)))
}
71 changes: 71 additions & 0 deletions model_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package pagorminator

import "testing"

func TestPagination_UnPaged(t *testing.T) {
t.Parallel()
tests := map[string]struct {
page int
size int
expected bool
}{
"page 0 size 0": {
page: 0,
size: 0,
expected: true,
},
"page zero size not zero": {
page: 0,
size: 1,
expected: false,
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()
page, err := PageRequest(test.page, test.size)
if err != nil {
t.Errorf("Unexpected error: %s", err)
}
if page.IsUnPaged() != test.expected {
t.Errorf("IsUnPaged() expected %v, got %v", test.expected, page.IsUnPaged())
}
})
}
}

func TestPagination_CalculateTotalPages(t *testing.T) {
t.Parallel()
tests := map[string]struct {
totalElements int64
size int
expected int
}{
"totalElements lower than size": {
totalElements: 2,
size: 4,
expected: 1,
},
"totalElements greater and not divisible by size": {
totalElements: 3,
size: 2,
expected: 2,
},
"totalElements greater and divisible by size": {
totalElements: 4,
size: 2,
expected: 2,
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()
actual := calculateTotalPages(test.totalElements, test.size)
if actual != test.expected {
t.Errorf("totalPages expected %v, got %v", test.expected, actual)
}
})
}
}
51 changes: 10 additions & 41 deletions pagorminator.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
package pagorminator

import (
"github.com/manuelarte/pagorminator/internal"
"gorm.io/gorm"
)

const (
countKey = "pagorminator.count"
)

func WithPagination(pageRequest PageRequest) func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Set("pagorminator:pageRequest", pageRequest)
}
}

var _ gorm.Plugin = new(PaGormMinator)

// PaGormMinator Gorm plugin to add pagination to your queries
// PaGormMinator Gorm plugin to add total elements and total pages to your pagination query
type PaGormMinator struct {
}

Expand All @@ -26,65 +19,41 @@ func (p PaGormMinator) Name() string {
}

func (p PaGormMinator) Initialize(db *gorm.DB) error {
err := db.Callback().Query().Before("gorm:query").Register("pagorminator:addPagination", p.addPagination)
if err != nil {
return err
}
err = db.Callback().Query().After("pagorminator:addPagination").Register("pagorminator:count", p.count)
err := db.Callback().Query().Before("gorm:query").Register("pagorminator:count", p.count)
if err != nil {
return err
}
return nil
}

func (p PaGormMinator) addPagination(db *gorm.DB) {
if db.Statement.Schema != nil {
if pageRequest, ok := p.getPageRequest(db); ok {
if !pageRequest.IsUnPaged() {
db.Limit(pageRequest.GetSize()).Offset(pageRequest.GetOffset())
}
}

}
}

func (p PaGormMinator) count(db *gorm.DB) {
if db.Statement.Schema != nil {
if pageRequest, ok := p.getPageRequest(db); ok {
if pageable, ok := p.getPageRequest(db); ok {
if value, ok := db.Get(countKey); !ok || !value.(bool) {
casted, _ := pageRequest.(*internal.PageRequestImpl)

newDb := db.Session(&gorm.Session{NewDB: true})
newDb.Statement = db.Statement.Statement

var totalElements int64
tx := newDb.Debug().Set(countKey, true).
Model(newDb.Statement.Model)
tx := newDb.Set(countKey, true).Model(newDb.Statement.Model)
if whereClause, existWhere := db.Statement.Clauses["WHERE"]; existWhere {
tx.Where(whereClause.Expression)
}
tx.Count(&totalElements)
if tx.Error != nil {
db.AddError(tx.Error)
_ = db.AddError(tx.Error)
} else {
casted.TotalElements = int(totalElements)
if casted.IsUnPaged() {
casted.Page = 0
casted.TotalPages = 1
} else {
casted.TotalPages = int(totalElements) / casted.Size
}
pageable.totalElements = totalElements
}
}
}

}
}

func (p PaGormMinator) getPageRequest(db *gorm.DB) (PageRequest, bool) {
if value, ok := db.Get("pagorminator:pageRequest"); ok {
if pageRequest, ok := value.(PageRequest); ok {
return pageRequest, true
func (p PaGormMinator) getPageRequest(db *gorm.DB) (*Pagination, bool) {
if value, ok := db.Get(pagorminatorClause); ok {
if paginationClause, ok := value.(*Pagination); ok {
return paginationClause, true
}
}
return nil, false
Expand Down
Loading

0 comments on commit 11fca8b

Please sign in to comment.