Skip to content

Commit

Permalink
Merge branch 'main' into endorser_protocol_askar_fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ianco authored Oct 22, 2021
2 parents 1efa692 + 984fa19 commit fe278e0
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 61 deletions.
2 changes: 1 addition & 1 deletion aries_cloudagent/core/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def setup(self):

# Register all outbound transports
self.outbound_transport_manager = OutboundTransportManager(
context, self.handle_not_delivered
self.root_profile, self.handle_not_delivered
)
await self.outbound_transport_manager.setup()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""V2.0 issue-credential linked data proof credential format handler."""

from ......vc.ld_proofs.error import LinkedDataProofException
from ......vc.ld_proofs.check import get_properties_without_context
import logging

from typing import Mapping
Expand Down Expand Up @@ -399,6 +401,17 @@ async def create_offer(
detail = LDProofVCDetail.deserialize(offer_data)
detail = await self._prepare_detail(detail)

document_loader = self.profile.inject(DocumentLoader)
missing_properties = get_properties_without_context(
detail.credential.serialize(), document_loader
)

if len(missing_properties) > 0:
raise LinkedDataProofException(
f"{len(missing_properties)} attributes dropped. "
f"Provide definitions in context to correct. {missing_properties}"
)

# Make sure we can issue with the did and proof type
await self._assert_can_issue_with_id_and_proof_type(
detail.credential.issuer_id, detail.options.proof_type
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from copy import deepcopy
from .......vc.ld_proofs.error import LinkedDataProofException
from asynctest import TestCase as AsyncTestCase
from asynctest import mock as async_mock
from unittest.mock import patch
from marshmallow import ValidationError

from .. import handler as test_module
Expand Down Expand Up @@ -406,7 +408,9 @@ async def test_create_offer(self):
LDProofCredFormatHandler,
"_assert_can_issue_with_id_and_proof_type",
async_mock.CoroutineMock(),
) as mock_can_issue:
) as mock_can_issue, patch.object(
test_module, "get_properties_without_context", return_value=[]
):
(cred_format, attachment) = await self.handler.create_offer(
self.cred_proposal
)
Expand Down Expand Up @@ -444,7 +448,7 @@ async def test_create_offer_adds_bbs_context(self):
LDProofCredFormatHandler,
"_assert_can_issue_with_id_and_proof_type",
async_mock.CoroutineMock(),
):
), patch.object(test_module, "get_properties_without_context", return_value=[]):
(cred_format, attachment) = await self.handler.create_offer(cred_proposal)

# assert BBS url added to context
Expand All @@ -457,6 +461,27 @@ async def test_create_offer_x_no_proposal(self):
context.exception
)

async def test_create_offer_x_wrong_attributes(self):
missing_properties = ["foo"]
with async_mock.patch.object(
LDProofCredFormatHandler,
"_assert_can_issue_with_id_and_proof_type",
async_mock.CoroutineMock(),
), patch.object(
test_module,
"get_properties_without_context",
return_value=missing_properties,
), self.assertRaises(
LinkedDataProofException
) as context:
await self.handler.create_offer(self.cred_proposal)

assert (
f"{len(missing_properties)} attributes dropped. "
f"Provide definitions in context to correct. {missing_properties}"
in str(context.exception)
)

async def test_receive_offer(self):
cred_ex_record = async_mock.MagicMock()
cred_offer_message = async_mock.MagicMock()
Expand Down
3 changes: 3 additions & 0 deletions aries_cloudagent/protocols/issue_credential/v2_0/routes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Credential exchange admin routes."""

from ....vc.ld_proofs.error import LinkedDataProofException
from json.decoder import JSONDecodeError
from typing import Mapping

Expand Down Expand Up @@ -1060,6 +1061,8 @@ async def credential_exchange_send_bound_offer(request: web.BaseRequest):
cred_ex_record,
outbound_handler,
)
except LinkedDataProofException as err:
raise web.HTTPBadRequest(reason=err) from err

await outbound_handler(cred_offer_message, connection_id=connection_id)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .....vc.ld_proofs.error import LinkedDataProofException
from asynctest import mock as async_mock, TestCase as AsyncTestCase

from .....admin.request_context import AdminRequestContext
Expand Down Expand Up @@ -802,6 +803,37 @@ async def test_credential_exchange_send_bound_offer(self):

mock_response.assert_called_once_with(mock_cx_rec.serialize.return_value)

async def test_credential_exchange_send_bound_offer_linked_data_error(self):
self.request.json = async_mock.CoroutineMock(return_value={})
self.request.match_info = {"cred_ex_id": "dummy"}

with async_mock.patch.object(
test_module, "ConnRecord", autospec=True
) as mock_conn_rec, async_mock.patch.object(
test_module, "V20CredManager", autospec=True
) as mock_cred_mgr, async_mock.patch.object(
test_module, "V20CredExRecord", autospec=True
) as mock_cx_rec_cls, async_mock.patch.object(
test_module.web, "json_response"
) as mock_response:
mock_cx_rec_cls.retrieve_by_id = async_mock.CoroutineMock()
mock_cx_rec_cls.retrieve_by_id.return_value.state = (
test_module.V20CredExRecord.STATE_PROPOSAL_RECEIVED
)

mock_cred_mgr.return_value.create_offer = async_mock.CoroutineMock()

mock_cx_rec = async_mock.MagicMock()

exception_message = "ex"
mock_cred_mgr.return_value.create_offer.side_effect = (
LinkedDataProofException(exception_message)
)
with self.assertRaises(test_module.web.HTTPBadRequest) as error:
await test_module.credential_exchange_send_bound_offer(self.request)

assert exception_message in str(error.exception)

async def test_credential_exchange_send_bound_offer_bad_cred_ex_id(self):
self.request.json = async_mock.CoroutineMock(return_value={})
self.request.match_info = {"cred_ex_id": "dummy"}
Expand Down
5 changes: 4 additions & 1 deletion aries_cloudagent/transport/outbound/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
class BaseOutboundTransport(ABC):
"""Base outbound transport class."""

def __init__(self, wire_format: BaseWireFormat = None) -> None:
def __init__(
self, wire_format: BaseWireFormat = None, root_profile: Profile = None
) -> None:
"""Initialize a `BaseOutboundTransport` instance."""
self._collector = None
self._wire_format = wire_format
self.root_profile = root_profile

@property
def collector(self) -> Collector:
Expand Down
4 changes: 2 additions & 2 deletions aries_cloudagent/transport/outbound/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class HttpTransport(BaseOutboundTransport):

schemes = ("http", "https")

def __init__(self) -> None:
def __init__(self, **kwargs) -> None:
"""Initialize an `HttpTransport` instance."""
super().__init__()
super().__init__(**kwargs)
self.client_session: ClientSession = None
self.connector: TCPConnector = None
self.logger = logging.getLogger(__name__)
Expand Down
33 changes: 17 additions & 16 deletions aries_cloudagent/transport/outbound/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from urllib.parse import urlparse

from ...connections.models.connection_target import ConnectionTarget
from ...config.injection_context import InjectionContext
from ...core.profile import Profile
from ...utils.classloader import ClassLoader, ModuleLoadError, ClassNotFoundError
from ...utils.stats import Collector
Expand Down Expand Up @@ -68,18 +67,16 @@ class OutboundTransportManager:

MAX_RETRY_COUNT = 4

def __init__(
self, context: InjectionContext, handle_not_delivered: Callable = None
):
def __init__(self, profile: Profile, handle_not_delivered: Callable = None):
"""
Initialize a `OutboundTransportManager` instance.
Args:
context: The application context
root_profile: The application root profile
handle_not_delivered: An optional handler for undelivered messages
"""
self.context = context
self.root_profile = profile
self.loop = asyncio.get_event_loop()
self.handle_not_delivered = handle_not_delivered
self.outbound_buffer = []
Expand All @@ -90,13 +87,15 @@ def __init__(
self.running_transports = {}
self.task_queue = TaskQueue(max_active=200)
self._process_task: asyncio.Task = None
if self.context.settings.get("transport.max_outbound_retry"):
self.MAX_RETRY_COUNT = self.context.settings["transport.max_outbound_retry"]
if self.root_profile.settings.get("transport.max_outbound_retry"):
self.MAX_RETRY_COUNT = self.root_profile.settings[
"transport.max_outbound_retry"
]

async def setup(self):
"""Perform setup operations."""
outbound_transports = (
self.context.settings.get("transport.outbound_configs") or []
self.root_profile.settings.get("transport.outbound_configs") or []
)
for outbound_transport in outbound_transports:
self.register(outbound_transport)
Expand Down Expand Up @@ -172,8 +171,10 @@ def register_class(

async def start_transport(self, transport_id: str):
"""Start a registered transport."""
transport = self.registered_transports[transport_id]()
transport.collector = self.context.inject_or(Collector)
transport = self.registered_transports[transport_id](
root_profile=self.root_profile
)
transport.collector = self.root_profile.inject_or(Collector)
await transport.start()
self.running_transports[transport_id] = transport

Expand Down Expand Up @@ -379,14 +380,14 @@ async def _process_loop(self):
if deliver:
queued.state = QueuedOutboundMessage.STATE_DELIVER
p_time = trace_event(
self.context.settings,
self.root_profile.settings,
queued.message if queued.message else queued.payload,
outcome="OutboundTransportManager.DELIVER.START."
+ queued.endpoint,
)
self.deliver_queued_message(queued)
trace_event(
self.context.settings,
self.root_profile.settings,
queued.message if queued.message else queued.payload,
outcome="OutboundTransportManager.DELIVER.END."
+ queued.endpoint,
Expand All @@ -408,13 +409,13 @@ async def _process_loop(self):
else:
queued.state = QueuedOutboundMessage.STATE_ENCODE
p_time = trace_event(
self.context.settings,
self.root_profile.settings,
queued.message if queued.message else queued.payload,
outcome="OutboundTransportManager.ENCODE.START",
)
self.encode_queued_message(queued)
trace_event(
self.context.settings,
self.root_profile.settings,
queued.message if queued.message else queued.payload,
outcome="OutboundTransportManager.ENCODE.END",
perf_counter=p_time,
Expand Down Expand Up @@ -449,7 +450,7 @@ async def perform_encode(
self, queued: QueuedOutboundMessage, wire_format: BaseWireFormat = None
):
"""Perform message encoding."""
wire_format = wire_format or self.context.inject(BaseWireFormat)
wire_format = wire_format or self.root_profile.inject(BaseWireFormat)

session = await queued.profile.session()
queued.payload = await wire_format.encode_message(
Expand Down
Loading

0 comments on commit fe278e0

Please sign in to comment.