Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: return if return route but no response #1853

7 changes: 6 additions & 1 deletion aries_cloudagent/transport/inbound/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ async def inbound_message_handler(self, request: web.BaseRequest):
raise web.HTTPBadRequest()

if inbound.receipt.direct_response_requested:
response = await session.wait_response()
# Wait for the message to be processed. Only send a response if a response
# buffer is present.
await inbound.wait_processing_complete()
response = (
await session.wait_response() if session.response_buffer else None
)

# no more responses
session.can_respond = False
Expand Down
2 changes: 2 additions & 0 deletions aries_cloudagent/transport/inbound/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def dispatch_complete(self, message: InboundMessage, completed: CompletedTask):
if session and session.accept_undelivered and not session.response_buffered:
self.process_undelivered(session)

message.dispatch_processing_complete()

def closed_session(self, session: InboundSession):
"""
Clean up a closed session.
Expand Down
10 changes: 10 additions & 0 deletions aries_cloudagent/transport/inbound/message.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Classes representing inbound messages."""

import asyncio
from typing import Union

from .receipt import MessageReceipt
Expand All @@ -23,3 +24,12 @@ def __init__(
self.receipt = receipt
self.session_id = session_id
self.transport_type = transport_type
self.processing_complete_event = asyncio.Event()

def dispatch_processing_complete(self):
"""Dispatch processing complete."""
self.processing_complete_event.set()

async def wait_processing_complete(self):
"""Wait for processing to complete."""
await self.processing_complete_event.wait()
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def receive_message(
message: InboundMessage,
can_respond: bool = False,
):
message.wait_processing_complete = async_mock.CoroutineMock()
self.message_results.append((message.payload, message.receipt, can_respond))
if self.result_event:
self.result_event.set()
Expand Down Expand Up @@ -119,13 +120,15 @@ async def test_send_message_outliers(self):
mock_session.return_value = async_mock.MagicMock(
receive=async_mock.CoroutineMock(
return_value=async_mock.MagicMock(
receipt=async_mock.MagicMock(direct_response_requested=True)
receipt=async_mock.MagicMock(direct_response_requested=True),
wait_processing_complete=async_mock.CoroutineMock(),
)
),
can_respond=True,
profile=InMemoryProfile.test_profile(),
clear_response=async_mock.MagicMock(),
wait_response=async_mock.CoroutineMock(return_value=b"Hello world"),
response_buffer="something",
)
async with self.client.post("/", data=test_message) as resp:
result = await resp.text()
Expand Down
30 changes: 30 additions & 0 deletions aries_cloudagent/transport/inbound/tests/test_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import asyncio

from asynctest import TestCase

from ..message import InboundMessage
from ..receipt import MessageReceipt


class TestInboundMessage(TestCase):
async def test_wait_response(self):
message = InboundMessage(
payload="test",
connection_id="conn_id",
receipt=MessageReceipt(),
session_id="session_id",
)
assert not message.processing_complete_event.is_set()
message.dispatch_processing_complete()
assert message.processing_complete_event.is_set()

message = InboundMessage(
payload="test",
connection_id="conn_id",
receipt=MessageReceipt(),
session_id="session_id",
)
assert not message.processing_complete_event.is_set()
task = message.wait_processing_complete()
message.dispatch_processing_complete()
await asyncio.wait_for(task, 1)