Skip to content

Commit

Permalink
refactor: Start DRY'ing filtered paginate code (#16099)
Browse files Browse the repository at this point in the history
Co-authored-by: testinginprod <[email protected]>
  • Loading branch information
ValarDragon and testinginprod authored May 15, 2023
1 parent e34a3e0 commit 660e906
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 174 deletions.
9 changes: 1 addition & 8 deletions types/query/collections_pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
predicateFunc func(key K, value V) (include bool, err error),
opts ...func(opt *CollectionsPaginateOptions[K]),
) ([]collections.KeyValue[K, V], *PageResponse, error) {
if pageReq == nil {
pageReq = &PageRequest{}
}
pageReq = initPageRequestDefaults(pageReq)

offset := pageReq.Offset
key := pageReq.Key
Expand All @@ -67,11 +65,6 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
return nil, nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}

if limit == 0 {
limit = DefaultLimit
countTotal = true
}

var (
results []collections.KeyValue[K, V]
pageRes *PageResponse
Expand Down
215 changes: 84 additions & 131 deletions types/query/filtered_pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,106 +22,115 @@ func FilteredPaginate(
pageRequest *PageRequest,
onResult func(key, value []byte, accumulate bool) (bool, error),
) (*PageResponse, error) {
// if the PageRequest is nil, use default PageRequest
if pageRequest == nil {
pageRequest = &PageRequest{}
}

offset := pageRequest.Offset
key := pageRequest.Key
limit := pageRequest.Limit
countTotal := pageRequest.CountTotal
reverse := pageRequest.Reverse
pageRequest = initPageRequestDefaults(pageRequest)

if offset > 0 && key != nil {
if pageRequest.Offset > 0 && pageRequest.Key != nil {
return nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}

if limit == 0 {
limit = DefaultLimit

// count total results when the limit is zero/not supplied
countTotal = true
}

if len(key) != 0 {
iterator := getIterator(prefixStore, key, reverse)
defer iterator.Close()
var (
numHits uint64
nextKey []byte
err error
)

var (
numHits uint64
nextKey []byte
)
iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse)
defer iterator.Close()

if len(pageRequest.Key) != 0 {
accumulateFn := func(_ uint64) bool { return true }
for ; iterator.Valid(); iterator.Next() {
if numHits == limit {
if numHits == pageRequest.Limit {
nextKey = iterator.Key()
break
}

if iterator.Error() != nil {
return nil, iterator.Error()
}

hit, err := onResult(iterator.Key(), iterator.Value(), true)
numHits, err = processResult(iterator, numHits, onResult, accumulateFn)
if err != nil {
return nil, err
}

if hit {
numHits++
}
}

return &PageResponse{
NextKey: nextKey,
}, nil
}

iterator := getIterator(prefixStore, nil, reverse)
defer iterator.Close()

end := offset + limit

var (
numHits uint64
nextKey []byte
)
end := pageRequest.Offset + pageRequest.Limit
accumulateFn := func(numHits uint64) bool { return numHits >= pageRequest.Offset && numHits < end }

for ; iterator.Valid(); iterator.Next() {
if iterator.Error() != nil {
return nil, iterator.Error()
}

accumulate := numHits >= offset && numHits < end
hit, err := onResult(iterator.Key(), iterator.Value(), accumulate)
numHits, err = processResult(iterator, numHits, onResult, accumulateFn)
if err != nil {
return nil, err
}

if hit {
numHits++
}

if numHits == end+1 {
if nextKey == nil {
nextKey = iterator.Key()
}

if !countTotal {
if !pageRequest.CountTotal {
break
}
}
}

res := &PageResponse{NextKey: nextKey}
if countTotal {
if pageRequest.CountTotal {
res.Total = numHits
}

return res, nil
}

func processResult(iterator types.Iterator, numHits uint64, onResult func(key, value []byte, accumulate bool) (bool, error), accumulateFn func(numHits uint64) bool) (uint64, error) {
if iterator.Error() != nil {
return numHits, iterator.Error()
}

accumulate := accumulateFn(numHits)
hit, err := onResult(iterator.Key(), iterator.Value(), accumulate)
if err != nil {
return numHits, err
}

if hit {
numHits++
}

return numHits, nil
}

func genericProcessResult[T, F proto.Message](iterator types.Iterator, numHits uint64, onResult func(key []byte, value T) (F, error), accumulateFn func(numHits uint64) bool,
constructor func() T, cdc codec.BinaryCodec, results []F,
) ([]F, uint64, error) {
if iterator.Error() != nil {
return results, numHits, iterator.Error()
}

protoMsg := constructor()

err := cdc.Unmarshal(iterator.Value(), protoMsg)
if err != nil {
return results, numHits, err
}

val, err := onResult(iterator.Key(), protoMsg)
if err != nil {
return results, numHits, err
}

if proto.Size(val) != 0 {
// Previously this was the "accumulate" flag
if accumulateFn(numHits) {
results = append(results, val)
}
numHits++
}

return results, numHits, nil
}

// GenericFilteredPaginate does pagination of all the results in the PrefixStore based on the
// provided PageRequest. `onResult` should be used to filter or transform the results.
// `c` is a constructor function that needs to return a new instance of the type T (this is to
Expand All @@ -137,119 +146,63 @@ func GenericFilteredPaginate[T, F proto.Message](
onResult func(key []byte, value T) (F, error),
constructor func() T,
) ([]F, *PageResponse, error) {
// if the PageRequest is nil, use default PageRequest
if pageRequest == nil {
pageRequest = &PageRequest{}
}

offset := pageRequest.Offset
key := pageRequest.Key
limit := pageRequest.Limit
countTotal := pageRequest.CountTotal
reverse := pageRequest.Reverse
pageRequest = initPageRequestDefaults(pageRequest)
results := []F{}

if offset > 0 && key != nil {
if pageRequest.Offset > 0 && pageRequest.Key != nil {
return results, nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}

if limit == 0 {
limit = DefaultLimit

// count total results when the limit is zero/not supplied
countTotal = true
}

if len(key) != 0 {
iterator := getIterator(prefixStore, key, reverse)
defer iterator.Close()
var (
numHits uint64
nextKey []byte
err error
)

var (
numHits uint64
nextKey []byte
)
iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse)
defer iterator.Close()

if len(pageRequest.Key) != 0 {
accumulateFn := func(_ uint64) bool { return true }
for ; iterator.Valid(); iterator.Next() {
if numHits == limit {
if numHits == pageRequest.Limit {
nextKey = iterator.Key()
break
}

if iterator.Error() != nil {
return nil, nil, iterator.Error()
}

protoMsg := constructor()

err := cdc.Unmarshal(iterator.Value(), protoMsg)
results, numHits, err = genericProcessResult(iterator, numHits, onResult, accumulateFn, constructor, cdc, results)
if err != nil {
return nil, nil, err
}

val, err := onResult(iterator.Key(), protoMsg)
if err != nil {
return nil, nil, err
}

if proto.Size(val) != 0 {
results = append(results, val)
numHits++
}
}

return results, &PageResponse{
NextKey: nextKey,
}, nil
}

iterator := getIterator(prefixStore, nil, reverse)
defer iterator.Close()

end := offset + limit

var (
numHits uint64
nextKey []byte
)
end := pageRequest.Offset + pageRequest.Limit
accumulateFn := func(numHits uint64) bool { return numHits >= pageRequest.Offset && numHits < end }

for ; iterator.Valid(); iterator.Next() {
if iterator.Error() != nil {
return nil, nil, iterator.Error()
}

protoMsg := constructor()

err := cdc.Unmarshal(iterator.Value(), protoMsg)
if err != nil {
return nil, nil, err
}

val, err := onResult(iterator.Key(), protoMsg)
results, numHits, err = genericProcessResult(iterator, numHits, onResult, accumulateFn, constructor, cdc, results)
if err != nil {
return nil, nil, err
}

if proto.Size(val) != 0 {
// Previously this was the "accumulate" flag
if numHits >= offset && numHits < end {
results = append(results, val)
}
numHits++
}

if numHits == end+1 {
if nextKey == nil {
nextKey = iterator.Key()
}

if !countTotal {
if !pageRequest.CountTotal {
break
}
}
}

res := &PageResponse{NextKey: nextKey}
if countTotal {
if pageRequest.CountTotal {
res.Total = numHits
}

Expand Down
Loading

0 comments on commit 660e906

Please sign in to comment.