Skip to content

Commit

Permalink
🚧
Browse files Browse the repository at this point in the history
  • Loading branch information
ff137 committed Aug 15, 2024
1 parent 4d9f164 commit 2529ac7
Show file tree
Hide file tree
Showing 13 changed files with 433 additions and 316 deletions.
6 changes: 6 additions & 0 deletions aries_cloudagent/messaging/models/base_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ async def query(
*,
limit: Optional[int] = None,
offset: Optional[int] = None,
order_by: Optional[str] = None,
descending: bool = False,
post_filter_positive: dict = None,
post_filter_negative: dict = None,
alt: bool = False,
Expand Down Expand Up @@ -327,11 +329,15 @@ async def query(
tag_query=tag_query,
limit=limit,
offset=offset,
order_by=order_by,
descending=descending,
)
else:
rows = await storage.find_all_records(
type_filter=cls.RECORD_TYPE,
tag_query=tag_query,
order_by=order_by,
descending=descending,
)

num_results_post_filter = 0 # used if applying pagination post-filter
Expand Down
31 changes: 26 additions & 5 deletions aries_cloudagent/messaging/models/paginated_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from aiohttp.web import BaseRequest
from marshmallow import fields
from marshmallow.validate import OneOf

from ...messaging.models.openapi import OpenAPISchema
from ...storage.base import DEFAULT_PAGE_SIZE, MAXIMUM_PAGE_SIZE
Expand Down Expand Up @@ -31,18 +32,38 @@ class PaginatedQuerySchema(OpenAPISchema):
metadata={"description": "Offset for pagination", "example": 0},
error_messages={"validator_failed": "Value must be 0 or greater"},
)
order_by = fields.Str(
required=False,
load_default=None,
dump_only=True, # Hide from schema by making it dump-only
load_only=True, # Ensure it can still be loaded/validated
validate=OneOf(["id"]), # Example of possible fields
metadata={"description": "Order results in descending order if true"},
error_messages={"validator_failed": "Ordering only support for column `id`"},
)
descending = fields.Bool(
required=False,
load_default=False,
metadata={"description": "Order results in descending order if true"},
)


def get_limit_offset(request: BaseRequest) -> Tuple[int, int]:
"""Read the limit and offset query parameters from a request as ints, with defaults.
def get_paginated_query_params(request: BaseRequest) -> Tuple[int, int, str, bool]:
"""Read the limit, offset, order_by, and descending query parameters from a request.
Args:
request: aiohttp request object
request: aiohttp request object.
Returns:
A tuple of the limit and offset values
A tuple containing:
- limit (int): The number of results to return, defaulting to DEFAULT_PAGE_SIZE.
- offset (int): The offset for pagination, defaulting to 0.
- order_by (str): The field by which to order results, defaulting to "id".
- descending (bool): Whether to order results in descending order, defaulting to False.
"""

limit = int(request.query.get("limit", DEFAULT_PAGE_SIZE))
offset = int(request.query.get("offset", 0))
return limit, offset
order_by = request.query.get("order_by", "id")
descending = bool(request.query.get("descending", False))
return limit, offset, order_by, descending
9 changes: 7 additions & 2 deletions aries_cloudagent/multitenant/admin/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from ...core.profile import ProfileManagerProvider
from ...messaging.models.base import BaseModelError
from ...messaging.models.openapi import OpenAPISchema
from ...messaging.models.paginated_query import PaginatedQuerySchema, get_limit_offset
from ...messaging.models.paginated_query import (
PaginatedQuerySchema,
get_paginated_query_params,
)
from ...messaging.valid import UUID4_EXAMPLE, JSONWebToken
from ...multitenant.base import BaseMultitenantManager
from ...storage.error import StorageError, StorageNotFoundError
Expand Down Expand Up @@ -382,7 +385,7 @@ async def wallets_list(request: web.BaseRequest):
if wallet_name:
query["wallet_name"] = wallet_name

limit, offset = get_limit_offset(request)
limit, offset, order_by, descending = get_paginated_query_params(request)

try:
async with profile.session() as session:
Expand All @@ -391,6 +394,8 @@ async def wallets_list(request: web.BaseRequest):
tag_filter=query,
limit=limit,
offset=offset,
order_by=order_by,
descending=descending,
)
results = [format_wallet_record(record) for record in records]
results.sort(key=lambda w: w["created_at"])
Expand Down
9 changes: 7 additions & 2 deletions aries_cloudagent/protocols/connections/v1_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
from ....connections.models.conn_record import ConnRecord, ConnRecordSchema
from ....messaging.models.base import BaseModelError
from ....messaging.models.openapi import OpenAPISchema
from ....messaging.models.paginated_query import PaginatedQuerySchema, get_limit_offset
from ....messaging.models.paginated_query import (
PaginatedQuerySchema,
get_paginated_query_params,
)
from ....messaging.valid import (
ENDPOINT_EXAMPLE,
ENDPOINT_VALIDATE,
Expand Down Expand Up @@ -469,7 +472,7 @@ async def connections_list(request: web.BaseRequest):
if request.query.get("connection_protocol"):
post_filter["connection_protocol"] = request.query["connection_protocol"]

limit, offset = get_limit_offset(request)
limit, offset, order_by, descending = get_paginated_query_params(request)

profile = context.profile
try:
Expand All @@ -479,6 +482,8 @@ async def connections_list(request: web.BaseRequest):
tag_filter,
limit=limit,
offset=offset,
order_by=order_by,
descending=descending,
post_filter_positive=post_filter,
alt=True,
)
Expand Down
9 changes: 7 additions & 2 deletions aries_cloudagent/protocols/issue_credential/v1_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from ....messaging.credential_definitions.util import CRED_DEF_TAGS
from ....messaging.models.base import BaseModelError
from ....messaging.models.openapi import OpenAPISchema
from ....messaging.models.paginated_query import PaginatedQuerySchema, get_limit_offset
from ....messaging.models.paginated_query import (
PaginatedQuerySchema,
get_paginated_query_params,
)
from ....messaging.valid import (
INDY_CRED_DEF_ID_EXAMPLE,
INDY_CRED_DEF_ID_VALIDATE,
Expand Down Expand Up @@ -404,7 +407,7 @@ async def credential_exchange_list(request: web.BaseRequest):
if request.query.get(k, "") != ""
}

limit, offset = get_limit_offset(request)
limit, offset, order_by, descending = get_paginated_query_params(request)

try:
async with context.profile.session() as session:
Expand All @@ -413,6 +416,8 @@ async def credential_exchange_list(request: web.BaseRequest):
tag_filter=tag_filter,
limit=limit,
offset=offset,
order_by=order_by,
descending=descending,
post_filter_positive=post_filter,
)
results = [record.serialize() for record in records]
Expand Down
9 changes: 7 additions & 2 deletions aries_cloudagent/protocols/issue_credential/v2_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
from ....messaging.decorators.attach_decorator import AttachDecorator
from ....messaging.models.base import BaseModelError
from ....messaging.models.openapi import OpenAPISchema
from ....messaging.models.paginated_query import PaginatedQuerySchema, get_limit_offset
from ....messaging.models.paginated_query import (
PaginatedQuerySchema,
get_paginated_query_params,
)
from ....messaging.valid import (
INDY_CRED_DEF_ID_EXAMPLE,
INDY_CRED_DEF_ID_VALIDATE,
Expand Down Expand Up @@ -568,7 +571,7 @@ async def credential_exchange_list(request: web.BaseRequest):
if request.query.get(k, "") != ""
}

limit, offset = get_limit_offset(request)
limit, offset, order_by, descending = get_paginated_query_params(request)

try:
async with profile.session() as session:
Expand All @@ -577,6 +580,8 @@ async def credential_exchange_list(request: web.BaseRequest):
tag_filter=tag_filter,
limit=limit,
offset=offset,
order_by=order_by,
descending=descending,
post_filter_positive=post_filter,
)

Expand Down
9 changes: 7 additions & 2 deletions aries_cloudagent/protocols/present_proof/v1_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from ....messaging.decorators.attach_decorator import AttachDecorator
from ....messaging.models.base import BaseModelError
from ....messaging.models.openapi import OpenAPISchema
from ....messaging.models.paginated_query import PaginatedQuerySchema, get_limit_offset
from ....messaging.models.paginated_query import (
PaginatedQuerySchema,
get_paginated_query_params,
)
from ....messaging.valid import (
INDY_EXTRA_WQL_EXAMPLE,
INDY_EXTRA_WQL_VALIDATE,
Expand Down Expand Up @@ -309,7 +312,7 @@ async def presentation_exchange_list(request: web.BaseRequest):
if request.query.get(k, "") != ""
}

limit, offset = get_limit_offset(request)
limit, offset, order_by, descending = get_paginated_query_params(request)

try:
async with context.profile.session() as session:
Expand All @@ -318,6 +321,8 @@ async def presentation_exchange_list(request: web.BaseRequest):
tag_filter=tag_filter,
limit=limit,
offset=offset,
order_by=order_by,
descending=descending,
post_filter_positive=post_filter,
)
results = [record.serialize() for record in records]
Expand Down
9 changes: 7 additions & 2 deletions aries_cloudagent/protocols/present_proof/v2_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
from ....messaging.decorators.attach_decorator import AttachDecorator
from ....messaging.models.base import BaseModelError
from ....messaging.models.openapi import OpenAPISchema
from ....messaging.models.paginated_query import PaginatedQuerySchema, get_limit_offset
from ....messaging.models.paginated_query import (
PaginatedQuerySchema,
get_paginated_query_params,
)
from ....messaging.valid import (
INDY_EXTRA_WQL_EXAMPLE,
INDY_EXTRA_WQL_VALIDATE,
Expand Down Expand Up @@ -448,7 +451,7 @@ async def present_proof_list(request: web.BaseRequest):
if request.query.get(k, "") != ""
}

limit, offset = get_limit_offset(request)
limit, offset, order_by, descending = get_paginated_query_params(request)

try:
async with profile.session() as session:
Expand All @@ -457,6 +460,8 @@ async def present_proof_list(request: web.BaseRequest):
tag_filter=tag_filter,
limit=limit,
offset=offset,
order_by=order_by,
descending=descending,
post_filter_positive=post_filter,
)
results = [record.serialize() for record in records]
Expand Down
9 changes: 9 additions & 0 deletions aries_cloudagent/storage/askar.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ async def find_paginated_records(
tag_query: Mapping = None,
limit: int = DEFAULT_PAGE_SIZE,
offset: int = 0,
order_by: Optional[str] = None,
descending: bool = False,
) -> Sequence[StorageRecord]:
"""Retrieve a page of records matching a particular type filter and tag query.
Expand All @@ -182,6 +184,11 @@ async def find_paginated_records(
tag_query: An optional dictionary of tag filter clauses
limit: The maximum number of records to retrieve
offset: The offset to start retrieving records from
order_by: An optional field by which to order the records.
descending: Whether to order the records in descending order.
Returns:
A sequence of StorageRecord matching the filter and query parameters.
"""
results = []

Expand All @@ -190,6 +197,8 @@ async def find_paginated_records(
tag_filter=tag_query,
limit=limit,
offset=offset,
order_by=order_by,
descending=descending,
profile=self._session.profile.settings.get("wallet.askar_profile"),
):
results += (
Expand Down
9 changes: 8 additions & 1 deletion aries_cloudagent/storage/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Abstract base classes for non-secrets storage."""

from abc import ABC, abstractmethod
from typing import Mapping, Sequence
from typing import Mapping, Optional, Sequence

from .error import StorageDuplicateError, StorageError, StorageNotFoundError
from .record import StorageRecord
Expand Down Expand Up @@ -96,6 +96,8 @@ async def find_paginated_records(
tag_query: Mapping = None,
limit: int = DEFAULT_PAGE_SIZE,
offset: int = 0,
order_by: Optional[str] = None,
descending: bool = False,
) -> Sequence[StorageRecord]:
"""Retrieve a page of records matching a particular type filter and tag query.
Expand All @@ -104,6 +106,11 @@ async def find_paginated_records(
tag_query: An optional dictionary of tag filter clauses
limit: The maximum number of records to retrieve
offset: The offset to start retrieving records from
order_by: An optional field by which to order the records.
descending: Whether to order the records in descending order.
Returns:
A sequence of StorageRecord matching the filter and query parameters.
"""

@abstractmethod
Expand Down
46 changes: 30 additions & 16 deletions aries_cloudagent/storage/in_memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Basic in-memory storage implementation (non-wallet)."""

from typing import Mapping, Sequence
from operator import attrgetter
from typing import Mapping, Optional, Sequence

from ..core.in_memory import InMemoryProfile
from .base import (
Expand Down Expand Up @@ -103,6 +104,8 @@ async def find_paginated_records(
tag_query: Mapping = None,
limit: int = DEFAULT_PAGE_SIZE,
offset: int = 0,
order_by: Optional[str] = None,
descending: bool = False,
) -> Sequence[StorageRecord]:
"""Retrieve a page of records matching a particular type filter and tag query.
Expand All @@ -111,21 +114,30 @@ async def find_paginated_records(
tag_query: An optional dictionary of tag filter clauses
limit: The maximum number of records to retrieve
offset: The offset to start retrieving records from
order_by: An optional field by which to order the records.
descending: Whether to order the records in descending order.
Returns:
A sequence of StorageRecord matching the filter and query parameters.
"""
results = []
skipped = 0
collected = 0
for record in self.profile.records.values():
if record.type == type_filter and tag_query_match(record.tags, tag_query):
if skipped < offset:
skipped += 1
continue
if collected < limit:
collected += 1
results.append(record)
else:
break
return results
# Filter records based on type and tag_query
filtered_records = [
record
for record in self.profile.records.values()
if record.type == type_filter and tag_query_match(record.tags, tag_query)
]

# Sort records if order_by is specified
if order_by:
try:
filtered_records.sort(key=attrgetter(order_by), reverse=descending)
except AttributeError:
raise ValueError(f"Invalid order_by field: {order_by}")

# Apply pagination (offset and limit)
paginated_records = filtered_records[offset : offset + limit]

return paginated_records

async def find_all_records(
self,
Expand Down Expand Up @@ -309,7 +321,9 @@ async def fetch(self, max_count: int = None) -> Sequence[StorageRecord]:
except StopIteration:
break
record = self._cache[id]
if record.type == check_type and tag_query_match(record.tags, self.tag_query):
if record.type == check_type and tag_query_match(
record.tags, self.tag_query
):
ret.append(record)
i -= 1

Expand Down
Loading

0 comments on commit 2529ac7

Please sign in to comment.