From 9867e35d779531eff6b33b615c04a9fba20eaea2 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Fri, 4 Feb 2022 19:02:55 -0800 Subject: [PATCH] support saving storage records with a predetermined ID (for tests); add for_update option to retrieval methods Signed-off-by: Andrew Whitehead --- .../messaging/models/base_record.py | 20 ++++++++++++++----- aries_cloudagent/storage/askar.py | 15 +++++++++----- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/aries_cloudagent/messaging/models/base_record.py b/aries_cloudagent/messaging/models/base_record.py index 07c56b2008..44ed51e4fe 100644 --- a/aries_cloudagent/messaging/models/base_record.py +++ b/aries_cloudagent/messaging/models/base_record.py @@ -89,6 +89,7 @@ def __init__( *, created_at: Union[str, datetime] = None, updated_at: Union[str, datetime] = None, + new_with_id: bool = False, ): """Initialize a new BaseRecord.""" if not self.RECORD_TYPE: @@ -99,6 +100,7 @@ def __init__( ) self._id = id self._last_state = state + self._new_with_id = new_with_id self.state = state self.created_at = datetime_to_str(created_at) self.updated_at = datetime_to_str(updated_at) @@ -218,7 +220,11 @@ async def clear_cached_key(cls, session: ProfileSession, cache_key: str): @classmethod async def retrieve_by_id( - cls: Type[RecordType], session: ProfileSession, record_id: str + cls: Type[RecordType], + session: ProfileSession, + record_id: str, + *, + for_update=False, ) -> RecordType: """ Retrieve a stored record by ID. @@ -230,7 +236,7 @@ async def retrieve_by_id( storage = session.inject(BaseStorage) result = await storage.get_record( - cls.RECORD_TYPE, record_id, {"retrieveTags": False} + cls.RECORD_TYPE, record_id, {"forUpdate": for_update, "retrieveTags": False} ) vals = json.loads(result.value) return cls.from_storage(record_id, vals) @@ -241,6 +247,8 @@ async def retrieve_by_tag_filter( session: ProfileSession, tag_filter: dict, post_filter: dict = None, + *, + for_update=False, ) -> RecordType: """ Retrieve a record by tag filter. @@ -256,7 +264,7 @@ async def retrieve_by_tag_filter( rows = await storage.find_all_records( cls.RECORD_TYPE, cls.prefix_tag_filter(tag_filter), - options={"retrieveTags": False}, + options={"forUpdate": for_update, "retrieveTags": False}, ) found = None for record in rows: @@ -349,15 +357,17 @@ async def save( try: self.updated_at = time_now() storage = session.inject(BaseStorage) - if self._id: + if self._id and not self._new_with_id: record = self.storage_record await storage.update_record(record, record.value, record.tags) new_record = False else: - self._id = str(uuid.uuid4()) + if not self._id: + self._id = str(uuid.uuid4()) self.created_at = self.updated_at await storage.add_record(self.storage_record) new_record = True + self._new_with_id = False finally: params = {self.RECORD_TYPE: self.serialize()} if log_params: diff --git a/aries_cloudagent/storage/askar.py b/aries_cloudagent/storage/askar.py index 75683d4af9..2a48cfb2a7 100644 --- a/aries_cloudagent/storage/askar.py +++ b/aries_cloudagent/storage/askar.py @@ -84,10 +84,11 @@ async def get_record( raise StorageError("Record type not provided") if not record_id: raise StorageError("Record ID not provided") - if not options: - options = {} + for_update = bool(options and options.get("forUpdate")) try: - item = await self._session.handle.fetch(record_type, record_id) + item = await self._session.handle.fetch( + record_type, record_id, for_update=for_update + ) except AskarError as err: raise StorageError("Error when fetching storage record") from err if not item: @@ -155,9 +156,10 @@ async def find_record( tag_query: Tags to query options: Dictionary of backend-specific options """ + for_update = bool(options and options.get("forUpdate")) try: results = await self._session.handle.fetch_all( - type_filter, tag_query, limit=2 + type_filter, tag_query, limit=2, for_update=for_update ) except AskarError as err: raise StorageError("Error when finding storage record") from err @@ -180,8 +182,11 @@ async def find_all_records( options: Mapping = None, ): """Retrieve all records matching a particular type filter and tag query.""" + for_update = bool(options and options.get("forUpdate")) results = [] - for row in await self._session.handle.fetch_all(type_filter, tag_query): + for row in await self._session.handle.fetch_all( + type_filter, tag_query, for_update=for_update + ): results.append( StorageRecord( type=row.category,