diff --git a/aries_cloudagent/storage/in_memory.py b/aries_cloudagent/storage/in_memory.py index f8d1f8f844..c891539ac2 100644 --- a/aries_cloudagent/storage/in_memory.py +++ b/aries_cloudagent/storage/in_memory.py @@ -1,6 +1,7 @@ """Basic in-memory storage implementation (non-wallet).""" -from typing import Mapping, Sequence +from operator import attrgetter +from typing import Mapping, Optional, Sequence from ..core.in_memory import InMemoryProfile from .base import ( @@ -103,6 +104,8 @@ async def find_paginated_records( tag_query: Mapping = None, limit: int = DEFAULT_PAGE_SIZE, offset: int = 0, + order_by: Optional[str] = None, + descending: bool = False, ) -> Sequence[StorageRecord]: """Retrieve a page of records matching a particular type filter and tag query. @@ -111,26 +114,37 @@ async def find_paginated_records( tag_query: An optional dictionary of tag filter clauses limit: The maximum number of records to retrieve offset: The offset to start retrieving records from + order_by: An optional field by which to order the records. + descending: Whether to order the records in descending order. + + Returns: + A sequence of StorageRecord matching the filter and query parameters. """ - results = [] - skipped = 0 - collected = 0 - for record in self.profile.records.values(): - if record.type == type_filter and tag_query_match(record.tags, tag_query): - if skipped < offset: - skipped += 1 - continue - if collected < limit: - collected += 1 - results.append(record) - else: - break - return results + # Filter records based on type and tag_query + filtered_records = [ + record + for record in self.profile.records.values() + if record.type == type_filter and tag_query_match(record.tags, tag_query) + ] + + # Sort records if order_by is specified + if order_by: + try: + filtered_records.sort(key=attrgetter(order_by), reverse=descending) + except AttributeError: + raise ValueError(f"Invalid order_by field: {order_by}") + + # Apply pagination (offset and limit) + paginated_records = filtered_records[offset : offset + limit] + + return paginated_records async def find_all_records( self, type_filter: str, tag_query: Mapping = None, + order_by: Optional[str] = None, + descending: bool = False, options: Mapping = None, ): """Retrieve all records matching a particular type filter and tag query.""" @@ -138,6 +152,14 @@ async def find_all_records( for record in self.profile.records.values(): if record.type == type_filter and tag_query_match(record.tags, tag_query): results.append(record) + + # Sort records if order_by is specified + if order_by: + try: + results.sort(key=attrgetter(order_by), reverse=descending) + except AttributeError: + raise ValueError(f"Invalid order_by field: {order_by}") + return results async def delete_all_records(