Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ConnectionProblemReport handler #2600

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ async def delete_record(self, session: ProfileSession):

async def abandon(self, session: ProfileSession, *, reason: Optional[str] = None):
"""Set state to abandoned."""
reason = reason or "Connectin abandoned"
reason = reason or "Connection abandoned"
self.state = ConnRecord.State.ABANDONED.rfc160
self.error_msg = reason
await self.save(session, reason=reason)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
BaseResponder,
RequestContext,
)

from ..messages.connection_invitation import ConnectionInvitation
from ..messages.problem_report import ConnectionProblemReport, ProblemReportReason

Expand All @@ -25,8 +24,12 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
assert isinstance(context.message, ConnectionInvitation)

report = ConnectionProblemReport(
problem_code=ProblemReportReason.INVITATION_NOT_ACCEPTED,
explain="Connection invitations cannot be submitted via agent messaging",
description={
"code": ProblemReportReason.INVITATION_NOT_ACCEPTED.value,
"en": (
"Connection invitations cannot be submitted via agent messaging"
),
}
)
# client likely needs to be using direct responses to receive the problem report
await responder.send_reply(report)
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
"Error parsing DIDDoc for problem report"
)
await responder.send_reply(
ConnectionProblemReport(problem_code=e.error_code, explain=str(e)),
ConnectionProblemReport(
description={"en": e.message, "code": e.error_code}
),
target_list=targets,
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
RequestContext,
)
from .....protocols.trustping.v1_0.messages.ping import Ping

from ..manager import ConnectionManager, ConnectionManagerError
from ..messages.connection_response import ConnectionResponse
from ..messages.problem_report import ConnectionProblemReport
Expand Down Expand Up @@ -46,7 +45,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
"Error parsing DIDDoc for problem report"
)
await responder.send_reply(
ConnectionProblemReport(problem_code=e.error_code, explain=str(e)),
ConnectionProblemReport(
description={"en": e.message, "code": e.error_code}
),
target_list=targets,
)
return
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Problem report handler for Connection Protocol."""

from .....messaging.base_handler import (
BaseHandler,
BaseResponder,
HandlerException,
RequestContext,
)
from ..manager import ConnectionManager, ConnectionManagerError
from ..messages.problem_report import ConnectionProblemReport


class ConnectionProblemReportHandler(BaseHandler):
"""Handler class for Connection problem report messages."""

async def handle(self, context: RequestContext, responder: BaseResponder):
"""Handle problem report message."""
self._logger.debug(
f"ConnectionProblemReportHandler called with context {context}"
)
assert isinstance(context.message, ConnectionProblemReport)

self._logger.info(f"Received problem report: {context.message.problem_code}")
profile = context.profile
mgr = ConnectionManager(profile)
try:
if context.connection_record:
await mgr.receive_problem_report(
context.connection_record, context.message
)
else:
raise HandlerException("No connection established for problem report")
except ConnectionManagerError:
# Unrecognized problem report code
self._logger.exception("Error receiving Connection problem report")
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder
from ......transport.inbound.receipt import MessageReceipt

from ...handlers.connection_invitation_handler import ConnectionInvitationHandler
from ...messages.connection_invitation import ConnectionInvitation
from ...messages.problem_report import ConnectionProblemReport, ProblemReportReason
Expand All @@ -28,6 +27,10 @@ async def test_problem_report(self, request_context):
result, target = messages[0]
assert (
isinstance(result, ConnectionProblemReport)
and result.problem_code == ProblemReportReason.INVITATION_NOT_ACCEPTED
and result.description
and (
result.description["code"]
== ProblemReportReason.INVITATION_NOT_ACCEPTED.value
)
)
assert not target
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import pytest

from aries_cloudagent.tests import mock

from ......core.profile import ProfileSession
from ......connections.models import connection_target
from ......connections.models.conn_record import ConnRecord
from ......connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service
from ......core.profile import ProfileSession
from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder
from ......transport.inbound.receipt import MessageReceipt
from ......storage.base import BaseStorage
from ......storage.error import StorageNotFoundError
from ......transport.inbound.receipt import MessageReceipt
from ...handlers import connection_request_handler as handler
from ...manager import ConnectionManagerError
from ...messages.connection_request import ConnectionRequest
Expand Down Expand Up @@ -161,7 +162,7 @@ async def test_connection_record_without_mediation_metadata(
async def test_problem_report(self, mock_conn_mgr, request_context):
mock_conn_mgr.return_value.receive_request = mock.CoroutineMock()
mock_conn_mgr.return_value.receive_request.side_effect = ConnectionManagerError(
error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED
error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value
)
request_context.message = ConnectionRequest()
handler_inst = handler.ConnectionRequestHandler()
Expand All @@ -172,7 +173,11 @@ async def test_problem_report(self, mock_conn_mgr, request_context):
result, target = messages[0]
assert (
isinstance(result, ConnectionProblemReport)
and result.problem_code == ProblemReportReason.REQUEST_NOT_ACCEPTED
and result.description
and (
result.description["code"]
== ProblemReportReason.REQUEST_NOT_ACCEPTED.value
)
)
assert target == {"target_list": None}

Expand All @@ -184,7 +189,7 @@ async def test_problem_report_did_doc(
):
mock_conn_mgr.return_value.receive_request = mock.CoroutineMock()
mock_conn_mgr.return_value.receive_request.side_effect = ConnectionManagerError(
error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED
error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value
)
mock_conn_mgr.return_value.diddoc_connection_targets = mock.MagicMock(
return_value=[mock_conn_target]
Expand All @@ -202,7 +207,11 @@ async def test_problem_report_did_doc(
result, target = messages[0]
assert (
isinstance(result, ConnectionProblemReport)
and result.problem_code == ProblemReportReason.REQUEST_NOT_ACCEPTED
and result.description
and (
result.description["code"]
== ProblemReportReason.REQUEST_NOT_ACCEPTED.value
)
)
assert target == {"target_list": [mock_conn_target]}

Expand All @@ -214,7 +223,7 @@ async def test_problem_report_did_doc_no_conn_target(
):
mock_conn_mgr.return_value.receive_request = mock.CoroutineMock()
mock_conn_mgr.return_value.receive_request.side_effect = ConnectionManagerError(
error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED
error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value
)
mock_conn_mgr.return_value.diddoc_connection_targets = mock.MagicMock(
side_effect=ConnectionManagerError("no targets")
Expand All @@ -232,6 +241,10 @@ async def test_problem_report_did_doc_no_conn_target(
result, target = messages[0]
assert (
isinstance(result, ConnectionProblemReport)
and result.problem_code == ProblemReportReason.REQUEST_NOT_ACCEPTED
and result.description
and (
result.description["code"]
== ProblemReportReason.REQUEST_NOT_ACCEPTED.value
)
)
assert target == {"target_list": None}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest

from aries_cloudagent.tests import mock

from ......connections.models import connection_target
Expand All @@ -10,11 +11,8 @@
)
from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder

from ......protocols.trustping.v1_0.messages.ping import Ping

from ......transport.inbound.receipt import MessageReceipt

from ...handlers import connection_response_handler as handler
from ...manager import ConnectionManagerError
from ...messages.connection_response import ConnectionResponse
Expand Down Expand Up @@ -101,7 +99,7 @@ async def test_called_auto_ping(self, mock_conn_mgr, request_context):
async def test_problem_report(self, mock_conn_mgr, request_context):
mock_conn_mgr.return_value.accept_response = mock.CoroutineMock()
mock_conn_mgr.return_value.accept_response.side_effect = ConnectionManagerError(
error_code=ProblemReportReason.RESPONSE_NOT_ACCEPTED
error_code=ProblemReportReason.RESPONSE_NOT_ACCEPTED.value,
)
request_context.message = ConnectionResponse()
handler_inst = handler.ConnectionResponseHandler()
Expand All @@ -112,7 +110,11 @@ async def test_problem_report(self, mock_conn_mgr, request_context):
result, target = messages[0]
assert (
isinstance(result, ConnectionProblemReport)
and result.problem_code == ProblemReportReason.RESPONSE_NOT_ACCEPTED
and result.description
and (
result.description["code"]
== ProblemReportReason.RESPONSE_NOT_ACCEPTED.value
)
)
assert target == {"target_list": None}

Expand All @@ -124,7 +126,7 @@ async def test_problem_report_did_doc(
):
mock_conn_mgr.return_value.accept_response = mock.CoroutineMock()
mock_conn_mgr.return_value.accept_response.side_effect = ConnectionManagerError(
error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED
error_code=ProblemReportReason.RESPONSE_NOT_ACCEPTED.value,
)
mock_conn_mgr.return_value.diddoc_connection_targets = mock.MagicMock(
return_value=[mock_conn_target]
Expand All @@ -140,7 +142,11 @@ async def test_problem_report_did_doc(
result, target = messages[0]
assert (
isinstance(result, ConnectionProblemReport)
and result.problem_code == ProblemReportReason.REQUEST_NOT_ACCEPTED
and result.description
and (
result.description["code"]
== ProblemReportReason.RESPONSE_NOT_ACCEPTED.value
)
)
assert target == {"target_list": [mock_conn_target]}

Expand All @@ -152,7 +158,7 @@ async def test_problem_report_did_doc_no_conn_target(
):
mock_conn_mgr.return_value.accept_response = mock.CoroutineMock()
mock_conn_mgr.return_value.accept_response.side_effect = ConnectionManagerError(
error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED
error_code=ProblemReportReason.RESPONSE_NOT_ACCEPTED.value,
)
mock_conn_mgr.return_value.diddoc_connection_targets = mock.MagicMock(
side_effect=ConnectionManagerError("no target")
Expand All @@ -168,6 +174,10 @@ async def test_problem_report_did_doc_no_conn_target(
result, target = messages[0]
assert (
isinstance(result, ConnectionProblemReport)
and result.problem_code == ProblemReportReason.REQUEST_NOT_ACCEPTED
and result.description
and (
result.description["code"]
== ProblemReportReason.RESPONSE_NOT_ACCEPTED.value
)
)
assert target == {"target_list": None}
37 changes: 30 additions & 7 deletions aries_cloudagent/protocols/connections/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import logging
from typing import Optional, Sequence, Tuple, cast


from ....core.oob_processor import OobMessageProcessor
from ....connections.base_manager import BaseConnectionManager
from ....connections.models.conn_record import ConnRecord
from ....core.error import BaseError
from ....core.oob_processor import OobMessageProcessor
from ....core.profile import Profile
from ....messaging.responder import BaseResponder
from ....messaging.valid import IndyDID
Expand All @@ -21,7 +20,7 @@
from .messages.connection_invitation import ConnectionInvitation
from .messages.connection_request import ConnectionRequest
from .messages.connection_response import ConnectionResponse
from .messages.problem_report import ProblemReportReason
from .messages.problem_report import ConnectionProblemReport, ProblemReportReason
from .models.connection_detail import ConnectionDetail


Expand Down Expand Up @@ -261,12 +260,12 @@ async def receive_invitation(
if not invitation.recipient_keys:
raise ConnectionManagerError(
"Invitation must contain recipient key(s)",
error_code="missing-recipient-keys",
error_code=ProblemReportReason.MISSING_RECIPIENT_KEYS.value,
)
if not invitation.endpoint:
raise ConnectionManagerError(
"Invitation must contain an endpoint",
error_code="missing-endpoint",
error_code=ProblemReportReason.MISSING_ENDPOINT.value,
)
accept = (
ConnRecord.ACCEPT_AUTO
Expand Down Expand Up @@ -440,7 +439,8 @@ async def receive_request(
raise ConnectionManagerError(
"No invitation found for pairwise connection "
f"in state {ConnRecord.State.INVITATION.rfc160}: "
"a prior connection request may have updated the connection state"
"a prior connection request may have updated the connection state",
error_code=ProblemReportReason.REQUEST_NOT_ACCEPTED.value,
)

invitation = None
Expand Down Expand Up @@ -489,7 +489,7 @@ async def receive_request(
conn_did_doc = request.connection.did_doc
if not conn_did_doc:
raise ConnectionManagerError(
"No DIDDoc provided; cannot connect to public DID"
"No DIDDoc provided; cannot connect to public DID",
)
if request.connection.did != conn_did_doc.did:
raise ConnectionManagerError(
Expand Down Expand Up @@ -757,3 +757,26 @@ async def accept_response(
await responder.send(request, connection_id=connection.connection_id)

return connection

async def receive_problem_report(
self,
conn_rec: ConnRecord,
report: ConnectionProblemReport,
):
"""Receive problem report."""
if not report.description:
raise ConnectionManagerError("Missing description in problem report")

if report.description.get("code") in {
reason.value for reason in ProblemReportReason
}:
self._logger.info("Problem report indicates connection is abandoned")
async with self.profile.session() as session:
await conn_rec.abandon(
session,
reason=report.description.get("en"),
)
else:
raise ConnectionManagerError(
f"Received unrecognized problem report: {report.description}"
)
Loading