Skip to content

Commit

Permalink
Merge pull request #1970 from shaangill025/mediation_fix
Browse files Browse the repository at this point in the history
Fix: `--mediator-invitation` with OOB invitation + cleanup
  • Loading branch information
swcurran authored Oct 13, 2022
2 parents 9dc2fa3 + 42b5029 commit 960aa91
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 18 deletions.
24 changes: 17 additions & 7 deletions aries_cloudagent/core/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..config.logging import LoggingConfigurator
from ..config.provider import ClassProvider
from ..config.wallet import wallet_config
from ..connections.models.conn_record import ConnRecord
from ..core.profile import Profile
from ..indy.verifier import IndyVerifier
from ..ledger.base import BaseLedger
Expand Down Expand Up @@ -450,14 +451,23 @@ async def start(self) -> None:
if mediation_connections_invite
else OutOfBandManager(self.root_profile)
)

conn_record = await mgr.receive_invitation(
invitation=invitation_handler.from_url(
mediation_invite_record.invite
),
auto_accept=True,
)
async with self.root_profile.session() as session:
invitation = invitation_handler.from_url(
mediation_invite_record.invite
)
if isinstance(mgr, OutOfBandManager):
oob_record = await mgr.receive_invitation(
invitation=invitation,
auto_accept=True,
)
conn_record = await ConnRecord.retrieve_by_id(
session, oob_record.connection_id
)
else:
conn_record = await mgr.receive_invitation(
invitation=invitation,
auto_accept=True,
)
await (
MediationInviteStore(
session.context.inject(BaseStorage)
Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,10 @@ async def make_message(
if not isinstance(parsed_msg, dict):
raise MessageParseError("Expected a JSON object")
message_type = parsed_msg.get("@type")
message_type_rec_version = get_version_from_message_type(message_type)

if not message_type:
raise MessageParseError("Message does not contain '@type' parameter")
message_type_rec_version = get_version_from_message_type(message_type)

registry: ProtocolRegistry = self.profile.inject(ProtocolRegistry)
try:
Expand Down
11 changes: 8 additions & 3 deletions aries_cloudagent/core/tests/test_conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,7 +1161,9 @@ async def test_mediator_invitation_0434(self, mock_from_url, _):
"test": async_mock.MagicMock(schemes=["http"])
}
await conductor.setup()

conductor.root_profile.context.update_settings(
{"mediation.connections_invite": False}
)
conn_record = ConnRecord(
invitation_key="3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx",
their_label="Hello",
Expand All @@ -1170,12 +1172,15 @@ async def test_mediator_invitation_0434(self, mock_from_url, _):
)
conn_record.accept = ConnRecord.ACCEPT_MANUAL
await conn_record.save(await conductor.root_profile.session())
oob_record = async_mock.MagicMock(
connection_id=conn_record.connection_id,
)
with async_mock.patch.object(
test_module,
"OutOfBandManager",
async_mock.MagicMock(
return_value=async_mock.MagicMock(
receive_invitation=async_mock.AsyncMock(return_value=conn_record)
receive_invitation=async_mock.AsyncMock(return_value=oob_record)
)
),
) as mock_mgr, async_mock.patch.object(
Expand All @@ -1185,10 +1190,10 @@ async def test_mediator_invitation_0434(self, mock_from_url, _):
return_value=async_mock.MagicMock(value=f"v{__version__}")
),
):
assert not conductor.root_profile.settings["mediation.connections_invite"]
await conductor.start()
await conductor.stop()
mock_from_url.assert_called_once_with("test-invite")
mock_mgr.return_value.receive_invitation.assert_called_once()

@async_mock.patch.object(test_module, "MediationInviteStore")
@async_mock.patch.object(test_module.ConnectionInvitation, "from_url")
Expand Down
17 changes: 11 additions & 6 deletions aries_cloudagent/protocols/connections/v1_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
request_schema,
response_schema,
)

from typing import cast
from marshmallow import fields, validate, validates_schema

from ....admin.request_context import AdminRequestContext
Expand Down Expand Up @@ -115,7 +115,7 @@ class CreateInvitationRequestSchema(OpenAPISchema):
mediation_id = fields.Str(
required=False,
description="Identifier for active mediation record to be used",
**MEDIATION_ID_SCHEMA
**MEDIATION_ID_SCHEMA,
)


Expand Down Expand Up @@ -247,7 +247,7 @@ class ReceiveInvitationQueryStringSchema(OpenAPISchema):
mediation_id = fields.Str(
required=False,
description="Identifier for active mediation record to be used",
**MEDIATION_ID_SCHEMA
**MEDIATION_ID_SCHEMA,
)


Expand All @@ -261,7 +261,7 @@ class AcceptInvitationQueryStringSchema(OpenAPISchema):
mediation_id = fields.Str(
required=False,
description="Identifier for active mediation record to be used",
**MEDIATION_ID_SCHEMA
**MEDIATION_ID_SCHEMA,
)


Expand Down Expand Up @@ -536,11 +536,16 @@ async def connections_create_invitation(request: web.BaseRequest):
metadata=metadata,
mediation_id=mediation_id,
)

invitation_url = invitation.to_url(base_url)
base_endpoint = service_endpoint or cast(
str, profile.settings.get("default_endpoint")
)
result = {
"connection_id": connection and connection.connection_id,
"invitation": invitation.serialize(),
"invitation_url": invitation.to_url(base_url),
"invitation_url": f"{base_endpoint}{invitation_url}"
if invitation_url.startswith("?")
else invitation_url,
}
except (ConnectionManagerError, StorageError, BaseModelError) as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -764,11 +764,15 @@ async def on_startup_event(profile: Profile, event: Event):
invite = InvitationMessage.from_url(endorser_invitation)
if invite:
oob_mgr = OutOfBandManager(profile)
conn_record = await oob_mgr.receive_invitation(
oob_record = await oob_mgr.receive_invitation(
invitation=invite,
auto_accept=True,
alias=endorser_alias,
)
async with profile.session() as session:
conn_record = await ConnRecord.retrieve_by_id(
session, oob_record.connection_id
)
else:
invite = ConnectionInvitation.from_url(endorser_invitation)
if invite:
Expand Down

0 comments on commit 960aa91

Please sign in to comment.