Skip to content

Commit

Permalink
Merge branch 'main' into feature/configurable-route-access
Browse files Browse the repository at this point in the history
  • Loading branch information
swcurran authored Jul 18, 2022
2 parents b94d833 + ad1cad3 commit 6bf8384
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 73 deletions.
1 change: 1 addition & 0 deletions aries_cloudagent/admin/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"acapy::actionmenu::received": "actionmenu",
"acapy::actionmenu::get-active-menu": "get-active-menu",
"acapy::actionmenu::perform-menu-action": "perform-menu-action",
"acapy::keylist::updated": "keylist",
}


Expand Down
3 changes: 2 additions & 1 deletion aries_cloudagent/core/event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Optional,
Pattern,
TYPE_CHECKING,
Tuple,
)
from functools import partial

Expand Down Expand Up @@ -193,7 +194,7 @@ class MockEventBus(EventBus):
def __init__(self):
"""Initialize MockEventBus."""
super().__init__()
self.events = []
self.events: List[Tuple[Profile, Event]] = []

async def notify(self, profile: "Profile", event: Event):
"""Append the event to MockEventBus.events."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
session, context.connection_record.connection_id
)
response = await mgr.update_keylist(record, updates=context.message.updates)
response.assign_thread_from(context.message)
await responder.send_reply(response)
except (StorageNotFoundError, MediationNotGrantedError):
reply = CMProblemReport(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
await mgr.store_update_results(
context.connection_record.connection_id, context.message.updated
)
await mgr.notify_keylist_updated(
context.connection_record.connection_id, context.message
)
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ async def test_handler(self):
handler, responder = KeylistUpdateResponseHandler(), MockResponder()
with async_mock.patch.object(
MediationManager, "store_update_results"
) as mock_method:
) as mock_store, async_mock.patch.object(
MediationManager, "notify_keylist_updated"
) as mock_notify:
await handler.handle(self.context, responder)
mock_method.assert_called_once_with(TEST_CONN_ID, self.updated)
mock_store.assert_called_once_with(TEST_CONN_ID, self.updated)
mock_notify.assert_called_once_with(TEST_CONN_ID, self.context.message)
119 changes: 68 additions & 51 deletions aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class MediationManager:
SET_TO_DEFAULT_ON_GRANTED = "set_to_default_on_granted"
METADATA_KEY = "mediation"
METADATA_ID = "id"
KEYLIST_UPDATED_EVENT = "acapy::keylist::updated"

def __init__(self, profile: Profile):
"""Initialize Mediation Manager.
Expand Down Expand Up @@ -534,63 +535,79 @@ async def store_update_results(
session: An active profile session
"""
session = await self._profile.session()
to_save: Sequence[RouteRecord] = []
to_remove: Sequence[RouteRecord] = []
for updated in results:
if updated.result != KeylistUpdated.RESULT_SUCCESS:
# TODO better handle different results?
LOGGER.warning(
"Keylist update failure: %s(%s): %s",
updated.action,
updated.recipient_key,
updated.result,
)
continue
if updated.action == KeylistUpdateRule.RULE_ADD:
# Multi-tenancy uses route record for internal relaying of wallets
# So the record could already exist. We update in that case
try:
record = await RouteRecord.retrieve_by_recipient_key(
session, updated.recipient_key
)
record.connection_id = connection_id
record.role = RouteRecord.ROLE_CLIENT
except StorageNotFoundError:
record = RouteRecord(
role=RouteRecord.ROLE_CLIENT,
recipient_key=updated.recipient_key,
connection_id=connection_id,
)
to_save.append(record)
elif updated.action == KeylistUpdateRule.RULE_REMOVE:
try:
records = await RouteRecord.query(
session,
{
"role": RouteRecord.ROLE_CLIENT,
"connection_id": connection_id,
"recipient_key": updated.recipient_key,
},
)
except StorageNotFoundError as err:
LOGGER.error(
"No route found while processing keylist update response: %s",
err,

async with self._profile.session() as session:
for updated in results:
if updated.result != KeylistUpdated.RESULT_SUCCESS:
# TODO better handle different results?
LOGGER.warning(
"Keylist update failure: %s(%s): %s",
updated.action,
updated.recipient_key,
updated.result,
)
else:
if len(records) > 1:
continue
if updated.action == KeylistUpdateRule.RULE_ADD:
# Multi-tenancy uses route record for internal relaying of wallets
# So the record could already exist. We update in that case
try:
record = await RouteRecord.retrieve_by_recipient_key(
session, updated.recipient_key
)
record.connection_id = connection_id
record.role = RouteRecord.ROLE_CLIENT
except StorageNotFoundError:
record = RouteRecord(
role=RouteRecord.ROLE_CLIENT,
recipient_key=updated.recipient_key,
connection_id=connection_id,
)
to_save.append(record)
elif updated.action == KeylistUpdateRule.RULE_REMOVE:
try:
records = await RouteRecord.query(
session,
{
"role": RouteRecord.ROLE_CLIENT,
"connection_id": connection_id,
"recipient_key": updated.recipient_key,
},
)
except StorageNotFoundError as err:
LOGGER.error(
f"Too many ({len(records)}) routes found "
"while processing keylist update response"
"No route found while processing keylist update response: %s",
err,
)
record = records[0]
to_remove.append(record)
else:
if len(records) > 1:
LOGGER.error(
f"Too many ({len(records)}) routes found "
"while processing keylist update response"
)
record = records[0]
to_remove.append(record)

for record_for_saving in to_save:
await record_for_saving.save(
session, reason="Route successfully added."
)
for record_for_removal in to_remove:
await record_for_removal.delete_record(session)

for record_for_saving in to_save:
await record_for_saving.save(session, reason="Route successfully added.")
for record_for_removal in to_remove:
await record_for_removal.delete_record(session)
async def notify_keylist_updated(
self, connection_id: str, response: KeylistUpdateResponse
):
"""Notify of keylist update response received."""
await self._profile.notify(
self.KEYLIST_UPDATED_EVENT,
{
"connection_id": connection_id,
"thread_id": response._thread_id,
"updated": [update.serialize() for update in response.updated],
},
)

async def get_my_keylist(
self, connection_id: Optional[str] = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
"""Test MediationManager."""
import logging

import pytest
from typing import AsyncIterable, Iterable

from asynctest import mock as async_mock
import pytest

from .. import manager as test_module
from .....core.event_bus import EventBus, MockEventBus
from .....core.in_memory import InMemoryProfile
from .....core.profile import Profile, ProfileSession
from .....connections.models.conn_record import ConnRecord
from .....messaging.request_context import RequestContext
from .....storage.error import StorageNotFoundError
from .....transport.inbound.receipt import MessageReceipt

from ....routing.v1_0.models.route_record import RouteRecord

from .. import manager as test_module
from ..manager import (
MediationAlreadyExists,
MediationManager,
Expand All @@ -22,12 +19,14 @@
)
from ..messages.inner.keylist_update_rule import KeylistUpdateRule
from ..messages.inner.keylist_updated import KeylistUpdated
from ..messages.keylist_update_response import KeylistUpdateResponse
from ..messages.mediate_deny import MediationDeny
from ..messages.mediate_grant import MediationGrant
from ..messages.mediate_request import MediationRequest
from ..models.mediation_record import MediationRecord

TEST_CONN_ID = "conn-id"
TEST_THREAD_ID = "thread-id"
TEST_ENDPOINT = "https://example.com"
TEST_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx"
TEST_ROUTE_VERKEY = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC"
Expand All @@ -36,29 +35,32 @@


@pytest.fixture
async def profile() -> Profile:
def profile() -> Iterable[Profile]:
"""Fixture for profile used in tests."""
# pylint: disable=W0621
context = RequestContext.test_context()
context.message_receipt = MessageReceipt(sender_verkey=TEST_VERKEY)
context.connection_record = ConnRecord(connection_id=TEST_CONN_ID)
yield context.profile
yield InMemoryProfile.test_profile(bind={EventBus: MockEventBus()})


@pytest.fixture
def mock_event_bus(profile: Profile):
yield profile.inject(EventBus)


@pytest.fixture
async def session(profile) -> ProfileSession: # pylint: disable=W0621
async def session(profile) -> AsyncIterable[ProfileSession]: # pylint: disable=W0621
"""Fixture for profile session used in tests."""
yield await profile.session()
async with profile.session() as session:
yield session


@pytest.fixture
async def manager(profile) -> MediationManager: # pylint: disable=W0621
def manager(profile) -> Iterable[MediationManager]: # pylint: disable=W0621
"""Fixture for manager used in tests."""
yield MediationManager(profile)


@pytest.fixture
def record() -> MediationRecord:
def record() -> Iterable[MediationRecord]:
"""Fixture for record used in tests."""
yield MediationRecord(
state=MediationRecord.STATE_GRANTED, connection_id=TEST_CONN_ID
Expand All @@ -71,7 +73,7 @@ class TestMediationManager: # pylint: disable=R0904,W0621
async def test_create_manager_no_profile(self):
"""test_create_manager_no_profile."""
with pytest.raises(MediationManagerError):
await MediationManager(None)
MediationManager(None)

async def test_create_did(self, manager, session):
"""test_create_did."""
Expand Down Expand Up @@ -363,7 +365,11 @@ async def test_add_remove_key_mix(self, manager):
assert update.updates[0].recipient_key == TEST_VERKEY
assert update.updates[1].recipient_key == TEST_ROUTE_VERKEY

async def test_store_update_results(self, session, manager):
async def test_store_update_results(
self,
session: ProfileSession,
manager: MediationManager,
):
"""test_store_update_results."""
await RouteRecord(
role=RouteRecord.ROLE_CLIENT,
Expand Down Expand Up @@ -467,6 +473,34 @@ async def test_store_update_results_errors(self, caplog, manager):
assert "server_error" in caplog.text
print(caplog.text)

async def test_notify_keylist_updated(
self, manager: MediationManager, mock_event_bus: MockEventBus
):
"""test notify_keylist_updated."""
response = KeylistUpdateResponse(
updated=[
KeylistUpdated(
recipient_key=TEST_ROUTE_VERKEY,
action=KeylistUpdateRule.RULE_ADD,
result=KeylistUpdated.RESULT_SUCCESS,
),
KeylistUpdated(
recipient_key=TEST_VERKEY,
action=KeylistUpdateRule.RULE_REMOVE,
result=KeylistUpdated.RESULT_SUCCESS,
),
],
)
response.assign_thread_id(TEST_THREAD_ID)
await manager.notify_keylist_updated(TEST_CONN_ID, response)
assert mock_event_bus.events
assert mock_event_bus.events[0][1].topic == manager.KEYLIST_UPDATED_EVENT
assert mock_event_bus.events[0][1].payload == {
"connection_id": TEST_CONN_ID,
"thread_id": TEST_THREAD_ID,
"updated": [result.serialize() for result in response.updated],
}

async def test_get_my_keylist(self, session, manager):
"""test_get_my_keylist."""
await RouteRecord(
Expand Down

0 comments on commit 6bf8384

Please sign in to comment.