Skip to content

Commit

Permalink
Merge pull request #2858 from dbluhm/fix/disclose-features-webhook
Browse files Browse the repository at this point in the history
fix: states for discovery record to emit webhook
  • Loading branch information
swcurran authored Mar 27, 2024
2 parents 41d8024 + 6ba81d3 commit 82d654e
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 107 deletions.
2 changes: 1 addition & 1 deletion aries_cloudagent/messaging/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def serialize(
) from err

@classmethod
def serde(cls, obj: Union["BaseModel", Mapping]) -> Optional[SerDe]:
def serde(cls, obj: Union["BaseModel", Mapping, None]) -> Optional[SerDe]:
"""Return serialized, deserialized representations of input object."""
if obj is None:
return None
Expand Down
154 changes: 80 additions & 74 deletions aries_cloudagent/protocols/discovery/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .messages.disclose import Disclose
from .messages.query import Query
from .models.discovery_record import V10DiscoveryExchangeRecord
from .models.discovery_record import V10DiscoveryExchangeRecord as DiscRecord


class V10DiscoveryMgrError(BaseError):
Expand Down Expand Up @@ -44,61 +44,64 @@ def profile(self) -> Profile:

async def receive_disclose(
self, disclose_msg: Disclose, connection_id: str
) -> V10DiscoveryExchangeRecord:
) -> DiscRecord:
"""Receive Disclose message and return updated V10DiscoveryExchangeRecord."""
if disclose_msg._thread:
thread_id = disclose_msg._thread.thid
try:
async with self._profile.session() as session:
discover_exch_rec = await V10DiscoveryExchangeRecord.retrieve_by_id(
async with self._profile.session() as session:
record = None
if disclose_msg._thread:
thread_id = disclose_msg._thread.thid
try:
record = await DiscRecord.retrieve_by_id(
session=session, record_id=thread_id
)
except StorageNotFoundError:
discover_exch_rec = await self.lookup_exchange_rec_by_connection(
connection_id
except StorageNotFoundError:
pass

if not record:
record = await DiscRecord.retrieve_if_exists_by_connection_id(
session, connection_id
)
if not discover_exch_rec:
discover_exch_rec = V10DiscoveryExchangeRecord()
else:
discover_exch_rec = await self.lookup_exchange_rec_by_connection(
connection_id
)
if not discover_exch_rec:
discover_exch_rec = V10DiscoveryExchangeRecord()
async with self._profile.session() as session:
discover_exch_rec.connection_id = connection_id
discover_exch_rec.disclose = disclose_msg
await discover_exch_rec.save(session)
return discover_exch_rec

if not record:
record = DiscRecord()

record.connection_id = connection_id
record.disclose = disclose_msg
record.state = DiscRecord.STATE_DISCLOSE_RECV
await record.save(session)

return record

async def lookup_exchange_rec_by_connection(
self, connection_id: str
) -> Optional[V10DiscoveryExchangeRecord]:
) -> Optional[DiscRecord]:
"""Retrieve V20DiscoveryExchangeRecord by connection_id."""
async with self._profile.session() as session:
if await V10DiscoveryExchangeRecord.exists_for_connection_id(
if await DiscRecord.exists_for_connection_id(
session=session, connection_id=connection_id
):
return await V10DiscoveryExchangeRecord.retrieve_by_connection_id(
return await DiscRecord.retrieve_by_connection_id(
session=session, connection_id=connection_id
)
else:
return None

async def receive_query(self, query_msg: Query) -> Disclose:
"""Process query and return the corresponding disclose message."""
registry = self._profile.context.inject_or(ProtocolRegistry)
registry = self._profile.context.inject(ProtocolRegistry)
query_str = query_msg.query
published_results = []
protocols = registry.protocols_matching_query(query_str)
results = await registry.prepare_disclosed(self._profile.context, protocols)

async with self._profile.session() as session:
to_publish_protocols = None
if (
session.settings.get("disclose_protocol_list")
and len(session.settings.get("disclose_protocol_list")) > 0
):
to_publish_protocols = session.settings.get("disclose_protocol_list")

for result in results:
to_publish_result = {}
if "pid" in result:
Expand All @@ -107,77 +110,80 @@ async def receive_query(self, query_msg: Query) -> Disclose:
and result.get("pid") not in to_publish_protocols
):
continue

to_publish_result["pid"] = result.get("pid")
else:
continue

if "roles" in result:
to_publish_result["roles"] = result.get("roles")

published_results.append(to_publish_result)

disclose_msg = Disclose(protocols=published_results)
# Check if query message has a thid
# If disclosing this agents feature
if query_msg._thread:
disclose_msg.assign_thread_id(query_msg._thread.thid)
return disclose_msg

async def check_if_disclosure_received(
self, record_id: str
) -> V10DiscoveryExchangeRecord:
async def check_if_disclosure_received(self, record_id: str) -> DiscRecord:
"""Check if disclosures has been received."""
while True:
async with self._profile.session() as session:
ex_rec = await V10DiscoveryExchangeRecord.retrieve_by_id(
ex_rec = await DiscRecord.retrieve_by_id(
session=session, record_id=record_id
)
if ex_rec.disclose:
return ex_rec
await asyncio.sleep(0.5)

async def create_and_send_query(
self, query: str, comment: str = None, connection_id: str = None
) -> V10DiscoveryExchangeRecord:
self,
query: str,
comment: Optional[str] = None,
connection_id: Optional[str] = None,
) -> DiscRecord:
"""Create and send a Query message."""
query_msg = Query(query=query, comment=comment)
if connection_id:
async with self._profile.session() as session:
# If existing record exists for a connection_id
if await V10DiscoveryExchangeRecord.exists_for_connection_id(
session=session, connection_id=connection_id
):
discovery_ex_rec = (
await V10DiscoveryExchangeRecord.retrieve_by_connection_id(
session=session, connection_id=connection_id
)
)
discovery_ex_rec.disclose = None
await discovery_ex_rec.save(session)
else:
discovery_ex_rec = V10DiscoveryExchangeRecord()
discovery_ex_rec.query_msg = query_msg
discovery_ex_rec.connection_id = connection_id
await discovery_ex_rec.save(session)
query_msg.assign_thread_id(discovery_ex_rec.discovery_exchange_id)
responder = self._profile.inject_or(BaseResponder)
if responder:
await responder.send(query_msg, connection_id=connection_id)
else:
self._logger.exception(
"Unable to send discover-features v1 query message"
": BaseResponder unavailable"
)
try:
return await asyncio.wait_for(
self.check_if_disclosure_received(
record_id=discovery_ex_rec.discovery_exchange_id,
),
5,
)
except asyncio.TimeoutError:
return discovery_ex_rec
else:
if not connection_id:
# Disclose this agent's features and/or goal codes
discovery_ex_rec = V10DiscoveryExchangeRecord()
discovery_ex_rec.query_msg = query_msg
record = DiscRecord()
record.query_msg = query_msg
disclose = await self.receive_query(query_msg=query_msg)
discovery_ex_rec.disclose = disclose
return discovery_ex_rec
record.disclose = disclose
return record

record = None
async with self._profile.session() as session:
record = await DiscRecord.retrieve_if_exists_by_connection_id(
session=session, connection_id=connection_id
)

if record:
record.disclose = None
else:
record = DiscRecord()

record.state = DiscRecord.STATE_QUERY_SENT
record.query_msg = query_msg
record.connection_id = connection_id

query_msg.assign_thread_id(record.discovery_exchange_id)
responder = self._profile.inject(BaseResponder)

await responder.send(query_msg, connection_id=connection_id)

async with self._profile.session() as session:
await record.save(session)

try:
return await asyncio.wait_for(
self.check_if_disclosure_received(
record_id=record.discovery_exchange_id,
),
5,
)

except asyncio.TimeoutError:
return record
5 changes: 4 additions & 1 deletion aries_cloudagent/protocols/discovery/v1_0/messages/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Represents a feature discovery query message."""

from typing import Optional
from marshmallow import EXCLUDE, fields

from .....messaging.agent_message import AgentMessage, AgentMessageSchema
Expand All @@ -22,7 +23,9 @@ class Meta:
message_type = QUERY
schema_class = "QuerySchema"

def __init__(self, *, query: str = None, comment: str = None, **kwargs):
def __init__(
self, *, query: Optional[str] = None, comment: Optional[str] = None, **kwargs
):
"""Initialize query message object.
Args:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""."""

import logging
from typing import Any, Mapping, Union
from typing import Any, Mapping, Optional, Union

from marshmallow import fields

Expand Down Expand Up @@ -29,18 +29,24 @@ class Meta:
RECORD_TOPIC = "discover_feature"
TAG_NAMES = {"~thread_id" if UNENCRYPTED_TAGS else "thread_id", "connection_id"}

STATE_QUERY_SENT = "query-sent"
STATE_DISCLOSE_RECV = "disclose-received"

def __init__(
self,
*,
discovery_exchange_id: str = None,
connection_id: str = None,
thread_id: str = None,
query_msg: Union[Mapping, Query] = None,
disclose: Union[Mapping, Disclose] = None,
state: Optional[str] = None,
discovery_exchange_id: Optional[str] = None,
connection_id: Optional[str] = None,
thread_id: Optional[str] = None,
query_msg: Union[Mapping, Query, None] = None,
disclose: Union[Mapping, Disclose, None] = None,
**kwargs,
):
"""Initialize a new V10DiscoveryExchangeRecord."""
super().__init__(discovery_exchange_id, **kwargs)
super().__init__(
discovery_exchange_id, state or self.STATE_QUERY_SENT, **kwargs
)
self._id = discovery_exchange_id
self.connection_id = connection_id
self.thread_id = thread_id
Expand Down Expand Up @@ -80,6 +86,21 @@ async def retrieve_by_connection_id(
tag_filter = {"connection_id": connection_id}
return await cls.retrieve_by_tag_filter(session, tag_filter)

@classmethod
async def retrieve_if_exists_by_connection_id(
cls, session: ProfileSession, connection_id: str
) -> Optional["V10DiscoveryExchangeRecord"]:
"""Retrieve a discovery exchange record by connection."""
tag_filter = {"connection_id": connection_id}
result = await cls.query(session, tag_filter)
if len(result) > 1:
LOGGER.warning(
"More than one disclosure record found for connection: %s",
connection_id,
)

return result[0] if result else None

@classmethod
async def exists_for_connection_id(
cls, session: ProfileSession, connection_id: str
Expand Down
24 changes: 0 additions & 24 deletions aries_cloudagent/protocols/discovery/v1_0/tests/test_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import asyncio
import logging
import pytest

from aries_cloudagent.tests import mock
Expand Down Expand Up @@ -182,28 +180,6 @@ async def test_create_and_send_query_with_connection(self):
assert received_ex_rec.query_msg == return_ex_rec.query_msg
mock_send.assert_called_once()

async def test_create_and_send_query_with_connection_no_responder(self):
self.profile.context.injector.clear_binding(BaseResponder)
with mock.patch.object(
V10DiscoveryExchangeRecord,
"exists_for_connection_id",
mock.CoroutineMock(),
) as mock_exists_for_connection_id, mock.patch.object(
V10DiscoveryExchangeRecord,
"save",
mock.CoroutineMock(),
) as save_ex, mock.patch.object(
V10DiscoveryMgr, "check_if_disclosure_received", mock.CoroutineMock()
) as mock_disclosure_received:
self._caplog.set_level(logging.WARNING)
mock_exists_for_connection_id.return_value = False
mock_disclosure_received.side_effect = asyncio.TimeoutError
received_ex_rec = await self.manager.create_and_send_query(
query="*", connection_id="test123"
)
assert received_ex_rec.query_msg.query == "*"
assert "Unable to send discover-features v1" in self._caplog.text

async def test_create_and_send_query_with_no_connection(self):
with mock.patch.object(
V10DiscoveryMgr,
Expand Down

0 comments on commit 82d654e

Please sign in to comment.