Skip to content

Commit

Permalink
support saving storage records with a predetermined ID (for tests); a…
Browse files Browse the repository at this point in the history
…dd for_update option to retrieval methods

Signed-off-by: Andrew Whitehead <[email protected]>
  • Loading branch information
andrewwhitehead committed Feb 8, 2022
1 parent bf9563c commit 9867e35
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
20 changes: 15 additions & 5 deletions aries_cloudagent/messaging/models/base_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
*,
created_at: Union[str, datetime] = None,
updated_at: Union[str, datetime] = None,
new_with_id: bool = False,
):
"""Initialize a new BaseRecord."""
if not self.RECORD_TYPE:
Expand All @@ -99,6 +100,7 @@ def __init__(
)
self._id = id
self._last_state = state
self._new_with_id = new_with_id
self.state = state
self.created_at = datetime_to_str(created_at)
self.updated_at = datetime_to_str(updated_at)
Expand Down Expand Up @@ -218,7 +220,11 @@ async def clear_cached_key(cls, session: ProfileSession, cache_key: str):

@classmethod
async def retrieve_by_id(
cls: Type[RecordType], session: ProfileSession, record_id: str
cls: Type[RecordType],
session: ProfileSession,
record_id: str,
*,
for_update=False,
) -> RecordType:
"""
Retrieve a stored record by ID.
Expand All @@ -230,7 +236,7 @@ async def retrieve_by_id(

storage = session.inject(BaseStorage)
result = await storage.get_record(
cls.RECORD_TYPE, record_id, {"retrieveTags": False}
cls.RECORD_TYPE, record_id, {"forUpdate": for_update, "retrieveTags": False}
)
vals = json.loads(result.value)
return cls.from_storage(record_id, vals)
Expand All @@ -241,6 +247,8 @@ async def retrieve_by_tag_filter(
session: ProfileSession,
tag_filter: dict,
post_filter: dict = None,
*,
for_update=False,
) -> RecordType:
"""
Retrieve a record by tag filter.
Expand All @@ -256,7 +264,7 @@ async def retrieve_by_tag_filter(
rows = await storage.find_all_records(
cls.RECORD_TYPE,
cls.prefix_tag_filter(tag_filter),
options={"retrieveTags": False},
options={"forUpdate": for_update, "retrieveTags": False},
)
found = None
for record in rows:
Expand Down Expand Up @@ -349,15 +357,17 @@ async def save(
try:
self.updated_at = time_now()
storage = session.inject(BaseStorage)
if self._id:
if self._id and not self._new_with_id:
record = self.storage_record
await storage.update_record(record, record.value, record.tags)
new_record = False
else:
self._id = str(uuid.uuid4())
if not self._id:
self._id = str(uuid.uuid4())
self.created_at = self.updated_at
await storage.add_record(self.storage_record)
new_record = True
self._new_with_id = False
finally:
params = {self.RECORD_TYPE: self.serialize()}
if log_params:
Expand Down
15 changes: 10 additions & 5 deletions aries_cloudagent/storage/askar.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ async def get_record(
raise StorageError("Record type not provided")
if not record_id:
raise StorageError("Record ID not provided")
if not options:
options = {}
for_update = bool(options and options.get("forUpdate"))
try:
item = await self._session.handle.fetch(record_type, record_id)
item = await self._session.handle.fetch(
record_type, record_id, for_update=for_update
)
except AskarError as err:
raise StorageError("Error when fetching storage record") from err
if not item:
Expand Down Expand Up @@ -155,9 +156,10 @@ async def find_record(
tag_query: Tags to query
options: Dictionary of backend-specific options
"""
for_update = bool(options and options.get("forUpdate"))
try:
results = await self._session.handle.fetch_all(
type_filter, tag_query, limit=2
type_filter, tag_query, limit=2, for_update=for_update
)
except AskarError as err:
raise StorageError("Error when finding storage record") from err
Expand All @@ -180,8 +182,11 @@ async def find_all_records(
options: Mapping = None,
):
"""Retrieve all records matching a particular type filter and tag query."""
for_update = bool(options and options.get("forUpdate"))
results = []
for row in await self._session.handle.fetch_all(type_filter, tag_query):
for row in await self._session.handle.fetch_all(
type_filter, tag_query, for_update=for_update
):
results.append(
StorageRecord(
type=row.category,
Expand Down

0 comments on commit 9867e35

Please sign in to comment.