Skip to content

Commit

Permalink
fix(query): filtered collections pagination (#16905)
Browse files Browse the repository at this point in the history
Co-authored-by: atheesh <[email protected]>
Co-authored-by: unknown unknown <unknown@unknown>
  • Loading branch information
3 people authored Jul 13, 2023
1 parent d868056 commit 9e098ca
Show file tree
Hide file tree
Showing 14 changed files with 228 additions and 119 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Ref: https://keepachangelog.com/en/1.0.0/

* (server) [#16827](https://github.com/cosmos/cosmos-sdk/pull/16827) Properly use `--trace` flag (before it was setting the trace level instead of displaying the stacktraces).
* (x/bank) [#16841](https://github.com/cosmos/cosmos-sdk/pull/16841) correctly process legacy `DenomAddressIndex` values.
* (types/query) [#16905](https://github.com/cosmos/cosmos-sdk/pull/16905) – Collections Pagination now applies proper count when filtering results.

### API Breaking Changes

Expand Down
104 changes: 76 additions & 28 deletions types/query/collections_pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,40 @@ type Collection[K, V any] interface {
KeyCodec() collcodec.KeyCodec[K]
}

// CollectionPaginate follows the same behavior as Paginate but works on a Collection.
func CollectionPaginate[K, V any, C Collection[K, V]](
// CollectionPaginate follows the same logic as Paginate but for collection types.
// transformFunc is used to transform the result to a different type.
func CollectionPaginate[K, V any, C Collection[K, V], T any](
ctx context.Context,
coll C,
pageReq *PageRequest,
) ([]collections.KeyValue[K, V], *PageResponse, error) {
return CollectionFilteredPaginate[K, V](ctx, coll, pageReq, nil)
transformFunc func(key K, value V) (T, error),
opts ...func(opt *CollectionsPaginateOptions[K]),
) ([]T, *PageResponse, error) {
return CollectionFilteredPaginate(
ctx,
coll,
pageReq,
nil,
transformFunc,
opts...,
)
}

// CollectionFilteredPaginate works in the same way as FilteredPaginate but for collection types.
// CollectionFilteredPaginate works in the same way as CollectionPaginate but allows to filter
// results using a predicateFunc.
// A nil predicateFunc means no filtering is applied and results are collected as is.
func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
// TransformFunc is applied only to results which are in range of the pagination and allow
// to convert the result to a different type.
// NOTE: do not collect results using the values/keys passed to predicateFunc as they are not
// guaranteed to be in the pagination range requested.
func CollectionFilteredPaginate[K, V any, C Collection[K, V], T any](
ctx context.Context,
coll C,
pageReq *PageRequest,
predicateFunc func(key K, value V) (include bool, err error),
transformFunc func(key K, value V) (T, error),
opts ...func(opt *CollectionsPaginateOptions[K]),
) ([]collections.KeyValue[K, V], *PageResponse, error) {
) (results []T, pageRes *PageResponse, err error) {
pageReq = initPageRequestDefaults(pageReq)

offset := pageReq.Offset
Expand All @@ -65,12 +81,6 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
return nil, nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}

var (
results []collections.KeyValue[K, V]
pageRes *PageResponse
err error
)

opt := new(CollectionsPaginateOptions[K])
for _, o := range opts {
o(opt)
Expand All @@ -85,9 +95,9 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
}

if len(key) != 0 {
results, pageRes, err = collFilteredPaginateByKey(ctx, coll, prefix, key, reverse, limit, predicateFunc)
results, pageRes, err = collFilteredPaginateByKey(ctx, coll, prefix, key, reverse, limit, predicateFunc, transformFunc)
} else {
results, pageRes, err = collFilteredPaginateNoKey(ctx, coll, prefix, reverse, offset, limit, countTotal, predicateFunc)
results, pageRes, err = collFilteredPaginateNoKey(ctx, coll, prefix, reverse, offset, limit, countTotal, predicateFunc, transformFunc)
}
// invalid iter error is ignored to retain Paginate behavior
if errors.Is(err, collections.ErrInvalidIterator) {
Expand All @@ -102,7 +112,7 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](

// collFilteredPaginateNoKey applies the provided pagination on the collection when the starting key is not set.
// If predicateFunc is nil no filtering is applied.
func collFilteredPaginateNoKey[K, V any, C Collection[K, V]](
func collFilteredPaginateNoKey[K, V any, C Collection[K, V], T any](
ctx context.Context,
coll C,
prefix []byte,
Expand All @@ -111,7 +121,8 @@ func collFilteredPaginateNoKey[K, V any, C Collection[K, V]](
limit uint64,
countTotal bool,
predicateFunc func(K, V) (bool, error),
) ([]collections.KeyValue[K, V], *PageResponse, error) {
transformFunc func(K, V) (T, error),
) ([]T, *PageResponse, error) {
iterator, err := getCollIter[K, V](ctx, coll, prefix, nil, reverse)
if err != nil {
return nil, nil, err
Expand All @@ -125,7 +136,7 @@ func collFilteredPaginateNoKey[K, V any, C Collection[K, V]](
var (
count uint64
nextKey []byte
results []collections.KeyValue[K, V]
results []T
)

for ; iterator.Valid(); iterator.Next() {
Expand All @@ -138,18 +149,28 @@ func collFilteredPaginateNoKey[K, V any, C Collection[K, V]](
}
// if no predicate function is specified then we just include the result
if predicateFunc == nil {
results = append(results, kv)
transformed, err := transformFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
results = append(results, transformed)
count++

// if predicate function is defined we check if the result matches the filtering criteria
} else {
include, err := predicateFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
if include {
results = append(results, kv)
transformed, err := transformFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
results = append(results, transformed)
count++
}
}
count++
// second case, we found all the objects specified within the limit
case count == limit:
key, err := iterator.Key()
Expand All @@ -172,12 +193,31 @@ func collFilteredPaginateNoKey[K, V any, C Collection[K, V]](
// but we need to count how many possible results exist in total.
// so we keep increasing the count until the iterator is fully consumed.
case count > limit:
count++
if predicateFunc == nil {
count++

// if predicate function is defined we check if the result matches the filtering criteria
} else {
kv, err := iterator.KeyValue()
if err != nil {
return nil, nil, err
}

include, err := predicateFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
if include {
count++
}
}
}
}

resp := &PageResponse{
NextKey: nextKey,
}

if countTotal {
resp.Total = count + offset
}
Expand All @@ -200,15 +240,16 @@ func advanceIter[I interface {

// collFilteredPaginateByKey paginates a collection when a starting key
// is provided in the PageRequest. Predicate is applied only if not nil.
func collFilteredPaginateByKey[K, V any, C Collection[K, V]](
func collFilteredPaginateByKey[K, V any, C Collection[K, V], T any](
ctx context.Context,
coll C,
prefix []byte,
key []byte,
reverse bool,
limit uint64,
predicateFunc func(K, V) (bool, error),
) ([]collections.KeyValue[K, V], *PageResponse, error) {
predicateFunc func(key K, value V) (bool, error),
transformFunc func(key K, value V) (transformed T, err error),
) (results []T, pageRes *PageResponse, err error) {
iterator, err := getCollIter[K, V](ctx, coll, prefix, key, reverse)
if err != nil {
return nil, nil, err
Expand All @@ -218,7 +259,6 @@ func collFilteredPaginateByKey[K, V any, C Collection[K, V]](
var (
count uint64
nextKey []byte
results []collections.KeyValue[K, V]
)

for ; iterator.Valid(); iterator.Next() {
Expand All @@ -243,7 +283,11 @@ func collFilteredPaginateByKey[K, V any, C Collection[K, V]](
}
// if no predicate is specified then we just append the result
if predicateFunc == nil {
results = append(results, kv)
transformed, err := transformFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
results = append(results, transformed)
// if predicate is applied we execute the predicate function
// and append only if predicateFunc yields true.
} else {
Expand All @@ -252,7 +296,11 @@ func collFilteredPaginateByKey[K, V any, C Collection[K, V]](
return nil, nil, err
}
if include {
results = append(results, kv)
transformed, err := transformFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
results = append(results, transformed)
}
}
count++
Expand Down
13 changes: 11 additions & 2 deletions types/query/collections_pagination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,15 @@ func TestCollectionPagination(t *testing.T) {
Limit: 3,
},
expResp: &PageResponse{
NextKey: encodeKey(3),
NextKey: encodeKey(5),
},
filter: func(key, value uint64) (bool, error) {
return key%2 == 0, nil
},
expResults: []collections.KeyValue[uint64, uint64]{
{Key: 0, Value: 0},
{Key: 2, Value: 2},
{Key: 4, Value: 4},
},
},
"filtered with key": {
Expand All @@ -131,7 +132,15 @@ func TestCollectionPagination(t *testing.T) {
for name, tc := range tcs {
tc := tc
t.Run(name, func(t *testing.T) {
gotResults, gotResponse, err := CollectionFilteredPaginate(ctx, m, tc.req, tc.filter)
gotResults, gotResponse, err := CollectionFilteredPaginate(
ctx,
m,
tc.req,
tc.filter,
func(key, value uint64) (collections.KeyValue[uint64, uint64], error) {
return collections.KeyValue[uint64, uint64]{Key: key, Value: value}, nil
},
)
if tc.wantErr != nil {
require.ErrorIs(t, err, tc.wantErr)
return
Expand Down
17 changes: 8 additions & 9 deletions x/auth/keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,14 @@ func (s queryServer) Accounts(ctx context.Context, req *types.QueryAccountsReque
return nil, status.Error(codes.InvalidArgument, "empty request")
}

var accounts []*codectypes.Any
_, pageRes, err := query.CollectionFilteredPaginate(ctx, s.k.Accounts, req.Pagination, func(_ sdk.AccAddress, value sdk.AccountI) (include bool, err error) {
accountAny, err := codectypes.NewAnyWithValue(value)
if err != nil {
return false, err
}
accounts = append(accounts, accountAny)
return false, nil // we don't include it since we're already appending the account
})
accounts, pageRes, err := query.CollectionPaginate(
ctx,
s.k.Accounts,
req.Pagination,
func(_ sdk.AccAddress, value sdk.AccountI) (*codectypes.Any, error) {
return codectypes.NewAnyWithValue(value)
},
)

return &types.QueryAccountsResponse{Accounts: accounts, Pagination: pageRes}, err
}
Expand Down
63 changes: 30 additions & 33 deletions x/bank/keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,20 @@ func (k BaseKeeper) AllBalances(ctx context.Context, req *types.QueryAllBalances
}

sdkCtx := sdk.UnwrapSDKContext(ctx)

balances := sdk.NewCoins()

_, pageRes, err := query.CollectionFilteredPaginate(ctx, k.Balances, req.Pagination, func(key collections.Pair[sdk.AccAddress, string], value math.Int) (include bool, err error) {
denom := key.K2()
if req.ResolveDenom {
if metadata, ok := k.GetDenomMetaData(sdkCtx, denom); ok {
denom = metadata.Display
balances, pageRes, err := query.CollectionPaginate(
ctx,
k.Balances,
req.Pagination,
func(key collections.Pair[sdk.AccAddress, string], value math.Int) (sdk.Coin, error) {
if req.ResolveDenom {
if metadata, ok := k.GetDenomMetaData(sdkCtx, key.K2()); ok {
return sdk.NewCoin(metadata.Display, value), nil
}
}
}
balances = append(balances, sdk.NewCoin(denom, value))
return false, nil // we don't include results because we're appending them here.
}, query.WithCollectionPaginationPairPrefix[sdk.AccAddress, string](addr))
return sdk.NewCoin(key.K2(), value), nil
},
query.WithCollectionPaginationPairPrefix[sdk.AccAddress, string](addr),
)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "paginate: %v", err)
}
Expand All @@ -94,12 +95,10 @@ func (k BaseKeeper) SpendableBalances(ctx context.Context, req *types.QuerySpend

sdkCtx := sdk.UnwrapSDKContext(ctx)

balances := sdk.NewCoins()
zeroAmt := math.ZeroInt()

_, pageRes, err := query.CollectionFilteredPaginate(ctx, k.Balances, req.Pagination, func(key collections.Pair[sdk.AccAddress, string], _ math.Int) (include bool, err error) {
balances = append(balances, sdk.NewCoin(key.K2(), zeroAmt))
return false, nil // not including results as they're appended here
balances, pageRes, err := query.CollectionPaginate(ctx, k.Balances, req.Pagination, func(key collections.Pair[sdk.AccAddress, string], _ math.Int) (coin sdk.Coin, err error) {
return sdk.NewCoin(key.K2(), zeroAmt), nil
}, query.WithCollectionPaginationPairPrefix[sdk.AccAddress, string](addr))
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "paginate: %v", err)
Expand Down Expand Up @@ -280,19 +279,16 @@ func (k BaseKeeper) DenomOwners(
return nil, status.Error(codes.InvalidArgument, err.Error())
}

var denomOwners []*types.DenomOwner

_, pageRes, err := query.CollectionFilteredPaginate(goCtx, k.Balances.Indexes.Denom, req.Pagination,
func(key collections.Pair[string, sdk.AccAddress], value collections.NoValue) (include bool, err error) {
denomOwners, pageRes, err := query.CollectionPaginate(
goCtx,
k.Balances.Indexes.Denom,
req.Pagination,
func(key collections.Pair[string, sdk.AccAddress], value collections.NoValue) (*types.DenomOwner, error) {
amt, err := k.Balances.Get(goCtx, collections.Join(key.K2(), req.Denom))
if err != nil {
return false, err
return nil, err
}
denomOwners = append(denomOwners, &types.DenomOwner{
Address: key.K2().String(),
Balance: sdk.NewCoin(req.Denom, amt),
})
return false, nil
return &types.DenomOwner{Address: key.K2().String(), Balance: sdk.NewCoin(req.Denom, amt)}, nil
},
query.WithCollectionPaginationPairPrefix[string, sdk.AccAddress](req.Denom),
)
Expand All @@ -316,16 +312,17 @@ func (k BaseKeeper) SendEnabled(goCtx context.Context, req *types.QuerySendEnabl
}
}
} else {
results, pageResp, err := query.CollectionPaginate[string, bool](ctx, k.BaseViewKeeper.SendEnabled, req.Pagination)
results, pageResp, err := query.CollectionPaginate(
ctx,
k.BaseViewKeeper.SendEnabled,
req.Pagination, func(key string, value bool) (*types.SendEnabled, error) {
return types.NewSendEnabled(key, value), nil
},
)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
for _, r := range results {
resp.SendEnabled = append(resp.SendEnabled, &types.SendEnabled{
Denom: r.Key,
Enabled: r.Value,
})
}
resp.SendEnabled = results
resp.Pagination = pageResp
}

Expand Down
Loading

0 comments on commit 9e098ca

Please sign in to comment.