Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: states for discovery record to emit webhook #2858

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aries_cloudagent/messaging/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def serialize(
) from err

@classmethod
def serde(cls, obj: Union["BaseModel", Mapping]) -> Optional[SerDe]:
def serde(cls, obj: Union["BaseModel", Mapping, None]) -> Optional[SerDe]:
"""Return serialized, deserialized representations of input object."""
if obj is None:
return None
Expand Down
154 changes: 80 additions & 74 deletions aries_cloudagent/protocols/discovery/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .messages.disclose import Disclose
from .messages.query import Query
from .models.discovery_record import V10DiscoveryExchangeRecord
from .models.discovery_record import V10DiscoveryExchangeRecord as DiscRecord


class V10DiscoveryMgrError(BaseError):
Expand Down Expand Up @@ -44,61 +44,64 @@ def profile(self) -> Profile:

async def receive_disclose(
self, disclose_msg: Disclose, connection_id: str
) -> V10DiscoveryExchangeRecord:
) -> DiscRecord:
"""Receive Disclose message and return updated V10DiscoveryExchangeRecord."""
if disclose_msg._thread:
thread_id = disclose_msg._thread.thid
try:
async with self._profile.session() as session:
discover_exch_rec = await V10DiscoveryExchangeRecord.retrieve_by_id(
async with self._profile.session() as session:
record = None
if disclose_msg._thread:
thread_id = disclose_msg._thread.thid
try:
record = await DiscRecord.retrieve_by_id(
session=session, record_id=thread_id
)
except StorageNotFoundError:
discover_exch_rec = await self.lookup_exchange_rec_by_connection(
connection_id
except StorageNotFoundError:
pass

if not record:
record = await DiscRecord.retrieve_if_exists_by_connection_id(
session, connection_id
)
if not discover_exch_rec:
discover_exch_rec = V10DiscoveryExchangeRecord()
else:
discover_exch_rec = await self.lookup_exchange_rec_by_connection(
connection_id
)
if not discover_exch_rec:
discover_exch_rec = V10DiscoveryExchangeRecord()
async with self._profile.session() as session:
discover_exch_rec.connection_id = connection_id
discover_exch_rec.disclose = disclose_msg
await discover_exch_rec.save(session)
return discover_exch_rec

if not record:
record = DiscRecord()

record.connection_id = connection_id
record.disclose = disclose_msg
record.state = DiscRecord.STATE_DISCLOSE_RECV
await record.save(session)

return record

async def lookup_exchange_rec_by_connection(
self, connection_id: str
) -> Optional[V10DiscoveryExchangeRecord]:
) -> Optional[DiscRecord]:
"""Retrieve V20DiscoveryExchangeRecord by connection_id."""
async with self._profile.session() as session:
if await V10DiscoveryExchangeRecord.exists_for_connection_id(
if await DiscRecord.exists_for_connection_id(
session=session, connection_id=connection_id
):
return await V10DiscoveryExchangeRecord.retrieve_by_connection_id(
return await DiscRecord.retrieve_by_connection_id(
session=session, connection_id=connection_id
)
else:
return None

async def receive_query(self, query_msg: Query) -> Disclose:
"""Process query and return the corresponding disclose message."""
registry = self._profile.context.inject_or(ProtocolRegistry)
registry = self._profile.context.inject(ProtocolRegistry)
query_str = query_msg.query
published_results = []
protocols = registry.protocols_matching_query(query_str)
results = await registry.prepare_disclosed(self._profile.context, protocols)

async with self._profile.session() as session:
to_publish_protocols = None
if (
session.settings.get("disclose_protocol_list")
and len(session.settings.get("disclose_protocol_list")) > 0
):
to_publish_protocols = session.settings.get("disclose_protocol_list")

for result in results:
to_publish_result = {}
if "pid" in result:
Expand All @@ -107,77 +110,80 @@ async def receive_query(self, query_msg: Query) -> Disclose:
and result.get("pid") not in to_publish_protocols
):
continue

to_publish_result["pid"] = result.get("pid")
else:
continue

if "roles" in result:
to_publish_result["roles"] = result.get("roles")

published_results.append(to_publish_result)

disclose_msg = Disclose(protocols=published_results)
# Check if query message has a thid
# If disclosing this agents feature
if query_msg._thread:
disclose_msg.assign_thread_id(query_msg._thread.thid)
return disclose_msg

async def check_if_disclosure_received(
self, record_id: str
) -> V10DiscoveryExchangeRecord:
async def check_if_disclosure_received(self, record_id: str) -> DiscRecord:
"""Check if disclosures has been received."""
while True:
async with self._profile.session() as session:
ex_rec = await V10DiscoveryExchangeRecord.retrieve_by_id(
ex_rec = await DiscRecord.retrieve_by_id(
session=session, record_id=record_id
)
if ex_rec.disclose:
return ex_rec
await asyncio.sleep(0.5)

async def create_and_send_query(
self, query: str, comment: str = None, connection_id: str = None
) -> V10DiscoveryExchangeRecord:
self,
query: str,
comment: Optional[str] = None,
connection_id: Optional[str] = None,
) -> DiscRecord:
"""Create and send a Query message."""
query_msg = Query(query=query, comment=comment)
if connection_id:
async with self._profile.session() as session:
# If existing record exists for a connection_id
if await V10DiscoveryExchangeRecord.exists_for_connection_id(
session=session, connection_id=connection_id
):
discovery_ex_rec = (
await V10DiscoveryExchangeRecord.retrieve_by_connection_id(
session=session, connection_id=connection_id
)
)
discovery_ex_rec.disclose = None
await discovery_ex_rec.save(session)
else:
discovery_ex_rec = V10DiscoveryExchangeRecord()
discovery_ex_rec.query_msg = query_msg
discovery_ex_rec.connection_id = connection_id
await discovery_ex_rec.save(session)
query_msg.assign_thread_id(discovery_ex_rec.discovery_exchange_id)
responder = self._profile.inject_or(BaseResponder)
if responder:
await responder.send(query_msg, connection_id=connection_id)
else:
self._logger.exception(
"Unable to send discover-features v1 query message"
": BaseResponder unavailable"
)
try:
return await asyncio.wait_for(
self.check_if_disclosure_received(
record_id=discovery_ex_rec.discovery_exchange_id,
),
5,
)
except asyncio.TimeoutError:
return discovery_ex_rec
else:
if not connection_id:
# Disclose this agent's features and/or goal codes
discovery_ex_rec = V10DiscoveryExchangeRecord()
discovery_ex_rec.query_msg = query_msg
record = DiscRecord()
record.query_msg = query_msg
disclose = await self.receive_query(query_msg=query_msg)
discovery_ex_rec.disclose = disclose
return discovery_ex_rec
record.disclose = disclose
return record

record = None
async with self._profile.session() as session:
record = await DiscRecord.retrieve_if_exists_by_connection_id(
session=session, connection_id=connection_id
)

if record:
record.disclose = None
else:
record = DiscRecord()

record.state = DiscRecord.STATE_QUERY_SENT
record.query_msg = query_msg
record.connection_id = connection_id

query_msg.assign_thread_id(record.discovery_exchange_id)
responder = self._profile.inject(BaseResponder)

await responder.send(query_msg, connection_id=connection_id)

async with self._profile.session() as session:
await record.save(session)

try:
return await asyncio.wait_for(
self.check_if_disclosure_received(
record_id=record.discovery_exchange_id,
),
5,
)

except asyncio.TimeoutError:
return record
5 changes: 4 additions & 1 deletion aries_cloudagent/protocols/discovery/v1_0/messages/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Represents a feature discovery query message."""

from typing import Optional
from marshmallow import EXCLUDE, fields

from .....messaging.agent_message import AgentMessage, AgentMessageSchema
Expand All @@ -22,7 +23,9 @@ class Meta:
message_type = QUERY
schema_class = "QuerySchema"

def __init__(self, *, query: str = None, comment: str = None, **kwargs):
def __init__(
self, *, query: Optional[str] = None, comment: Optional[str] = None, **kwargs
):
"""Initialize query message object.

Args:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""."""

import logging
from typing import Any, Mapping, Union
from typing import Any, Mapping, Optional, Union

from marshmallow import fields

Expand Down Expand Up @@ -29,18 +29,24 @@ class Meta:
RECORD_TOPIC = "discover_feature"
TAG_NAMES = {"~thread_id" if UNENCRYPTED_TAGS else "thread_id", "connection_id"}

STATE_QUERY_SENT = "query-sent"
STATE_DISCLOSE_RECV = "disclose-received"

def __init__(
self,
*,
discovery_exchange_id: str = None,
connection_id: str = None,
thread_id: str = None,
query_msg: Union[Mapping, Query] = None,
disclose: Union[Mapping, Disclose] = None,
state: Optional[str] = None,
discovery_exchange_id: Optional[str] = None,
connection_id: Optional[str] = None,
thread_id: Optional[str] = None,
query_msg: Union[Mapping, Query, None] = None,
disclose: Union[Mapping, Disclose, None] = None,
**kwargs,
):
"""Initialize a new V10DiscoveryExchangeRecord."""
super().__init__(discovery_exchange_id, **kwargs)
super().__init__(
discovery_exchange_id, state or self.STATE_QUERY_SENT, **kwargs
)
self._id = discovery_exchange_id
self.connection_id = connection_id
self.thread_id = thread_id
Expand Down Expand Up @@ -80,6 +86,21 @@ async def retrieve_by_connection_id(
tag_filter = {"connection_id": connection_id}
return await cls.retrieve_by_tag_filter(session, tag_filter)

@classmethod
async def retrieve_if_exists_by_connection_id(
cls, session: ProfileSession, connection_id: str
) -> Optional["V10DiscoveryExchangeRecord"]:
"""Retrieve a discovery exchange record by connection."""
tag_filter = {"connection_id": connection_id}
result = await cls.query(session, tag_filter)
if len(result) > 1:
LOGGER.warning(
"More than one disclosure record found for connection: %s",
connection_id,
)

return result[0] if result else None

@classmethod
async def exists_for_connection_id(
cls, session: ProfileSession, connection_id: str
Expand Down
24 changes: 0 additions & 24 deletions aries_cloudagent/protocols/discovery/v1_0/tests/test_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import asyncio
import logging
import pytest

from aries_cloudagent.tests import mock
Expand Down Expand Up @@ -182,28 +180,6 @@ async def test_create_and_send_query_with_connection(self):
assert received_ex_rec.query_msg == return_ex_rec.query_msg
mock_send.assert_called_once()

async def test_create_and_send_query_with_connection_no_responder(self):
self.profile.context.injector.clear_binding(BaseResponder)
with mock.patch.object(
V10DiscoveryExchangeRecord,
"exists_for_connection_id",
mock.CoroutineMock(),
) as mock_exists_for_connection_id, mock.patch.object(
V10DiscoveryExchangeRecord,
"save",
mock.CoroutineMock(),
) as save_ex, mock.patch.object(
V10DiscoveryMgr, "check_if_disclosure_received", mock.CoroutineMock()
) as mock_disclosure_received:
self._caplog.set_level(logging.WARNING)
mock_exists_for_connection_id.return_value = False
mock_disclosure_received.side_effect = asyncio.TimeoutError
received_ex_rec = await self.manager.create_and_send_query(
query="*", connection_id="test123"
)
assert received_ex_rec.query_msg.query == "*"
assert "Unable to send discover-features v1" in self._caplog.text

async def test_create_and_send_query_with_no_connection(self):
with mock.patch.object(
V10DiscoveryMgr,
Expand Down