diff --git a/aries_cloudagent/transport/inbound/http.py b/aries_cloudagent/transport/inbound/http.py index 8179023379..696cab4bd3 100644 --- a/aries_cloudagent/transport/inbound/http.py +++ b/aries_cloudagent/transport/inbound/http.py @@ -96,7 +96,12 @@ async def inbound_message_handler(self, request: web.BaseRequest): raise web.HTTPBadRequest() if inbound.receipt.direct_response_requested: - response = await session.wait_response() + # Wait for the message to be processed. Only send a response if a response + # buffer is present. + await inbound.wait_processing_complete() + response = ( + await session.wait_response() if session.response_buffer else None + ) # no more responses session.can_respond = False diff --git a/aries_cloudagent/transport/inbound/manager.py b/aries_cloudagent/transport/inbound/manager.py index 65d081454d..909703e7c5 100644 --- a/aries_cloudagent/transport/inbound/manager.py +++ b/aries_cloudagent/transport/inbound/manager.py @@ -181,6 +181,8 @@ def dispatch_complete(self, message: InboundMessage, completed: CompletedTask): if session and session.accept_undelivered and not session.response_buffered: self.process_undelivered(session) + message.dispatch_processing_complete() + def closed_session(self, session: InboundSession): """ Clean up a closed session. diff --git a/aries_cloudagent/transport/inbound/message.py b/aries_cloudagent/transport/inbound/message.py index 169b2dc35c..2def0e1d92 100644 --- a/aries_cloudagent/transport/inbound/message.py +++ b/aries_cloudagent/transport/inbound/message.py @@ -1,5 +1,6 @@ """Classes representing inbound messages.""" +import asyncio from typing import Union from .receipt import MessageReceipt @@ -23,3 +24,12 @@ def __init__( self.receipt = receipt self.session_id = session_id self.transport_type = transport_type + self.processing_complete_event = asyncio.Event() + + def dispatch_processing_complete(self): + """Dispatch processing complete.""" + self.processing_complete_event.set() + + async def wait_processing_complete(self): + """Wait for processing to complete.""" + await self.processing_complete_event.wait() diff --git a/aries_cloudagent/transport/inbound/tests/test_http_transport.py b/aries_cloudagent/transport/inbound/tests/test_http_transport.py index c440fe7a57..fd74775572 100644 --- a/aries_cloudagent/transport/inbound/tests/test_http_transport.py +++ b/aries_cloudagent/transport/inbound/tests/test_http_transport.py @@ -62,6 +62,7 @@ def receive_message( message: InboundMessage, can_respond: bool = False, ): + message.wait_processing_complete = async_mock.CoroutineMock() self.message_results.append((message.payload, message.receipt, can_respond)) if self.result_event: self.result_event.set() @@ -119,13 +120,15 @@ async def test_send_message_outliers(self): mock_session.return_value = async_mock.MagicMock( receive=async_mock.CoroutineMock( return_value=async_mock.MagicMock( - receipt=async_mock.MagicMock(direct_response_requested=True) + receipt=async_mock.MagicMock(direct_response_requested=True), + wait_processing_complete=async_mock.CoroutineMock(), ) ), can_respond=True, profile=InMemoryProfile.test_profile(), clear_response=async_mock.MagicMock(), wait_response=async_mock.CoroutineMock(return_value=b"Hello world"), + response_buffer="something", ) async with self.client.post("/", data=test_message) as resp: result = await resp.text() diff --git a/aries_cloudagent/transport/inbound/tests/test_message.py b/aries_cloudagent/transport/inbound/tests/test_message.py new file mode 100644 index 0000000000..71a8defee8 --- /dev/null +++ b/aries_cloudagent/transport/inbound/tests/test_message.py @@ -0,0 +1,30 @@ +import asyncio + +from asynctest import TestCase + +from ..message import InboundMessage +from ..receipt import MessageReceipt + + +class TestInboundMessage(TestCase): + async def test_wait_response(self): + message = InboundMessage( + payload="test", + connection_id="conn_id", + receipt=MessageReceipt(), + session_id="session_id", + ) + assert not message.processing_complete_event.is_set() + message.dispatch_processing_complete() + assert message.processing_complete_event.is_set() + + message = InboundMessage( + payload="test", + connection_id="conn_id", + receipt=MessageReceipt(), + session_id="session_id", + ) + assert not message.processing_complete_event.is_set() + task = message.wait_processing_complete() + message.dispatch_processing_complete() + await asyncio.wait_for(task, 1)