Skip to content

Commit

Permalink
Merge pull request #1853 from TimoGlastra/fix/return-processing-no-re…
Browse files Browse the repository at this point in the history
…sponse

fix: return if return route but no response
  • Loading branch information
swcurran authored Nov 8, 2022
2 parents 22ca606 + bc084cb commit fe08978
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 2 deletions.
7 changes: 6 additions & 1 deletion aries_cloudagent/transport/inbound/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ async def inbound_message_handler(self, request: web.BaseRequest):
raise web.HTTPBadRequest()

if inbound.receipt.direct_response_requested:
response = await session.wait_response()
# Wait for the message to be processed. Only send a response if a response
# buffer is present.
await inbound.wait_processing_complete()
response = (
await session.wait_response() if session.response_buffer else None
)

# no more responses
session.can_respond = False
Expand Down
2 changes: 2 additions & 0 deletions aries_cloudagent/transport/inbound/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def dispatch_complete(self, message: InboundMessage, completed: CompletedTask):
if session and session.accept_undelivered and not session.response_buffered:
self.process_undelivered(session)

message.dispatch_processing_complete()

def closed_session(self, session: InboundSession):
"""
Clean up a closed session.
Expand Down
10 changes: 10 additions & 0 deletions aries_cloudagent/transport/inbound/message.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Classes representing inbound messages."""

import asyncio
from typing import Union

from .receipt import MessageReceipt
Expand All @@ -23,3 +24,12 @@ def __init__(
self.receipt = receipt
self.session_id = session_id
self.transport_type = transport_type
self.processing_complete_event = asyncio.Event()

def dispatch_processing_complete(self):
"""Dispatch processing complete."""
self.processing_complete_event.set()

async def wait_processing_complete(self):
"""Wait for processing to complete."""
await self.processing_complete_event.wait()
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def receive_message(
message: InboundMessage,
can_respond: bool = False,
):
message.wait_processing_complete = async_mock.CoroutineMock()
self.message_results.append((message.payload, message.receipt, can_respond))
if self.result_event:
self.result_event.set()
Expand Down Expand Up @@ -119,13 +120,15 @@ async def test_send_message_outliers(self):
mock_session.return_value = async_mock.MagicMock(
receive=async_mock.CoroutineMock(
return_value=async_mock.MagicMock(
receipt=async_mock.MagicMock(direct_response_requested=True)
receipt=async_mock.MagicMock(direct_response_requested=True),
wait_processing_complete=async_mock.CoroutineMock(),
)
),
can_respond=True,
profile=InMemoryProfile.test_profile(),
clear_response=async_mock.MagicMock(),
wait_response=async_mock.CoroutineMock(return_value=b"Hello world"),
response_buffer="something",
)
async with self.client.post("/", data=test_message) as resp:
result = await resp.text()
Expand Down
30 changes: 30 additions & 0 deletions aries_cloudagent/transport/inbound/tests/test_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import asyncio

from asynctest import TestCase

from ..message import InboundMessage
from ..receipt import MessageReceipt


class TestInboundMessage(TestCase):
async def test_wait_response(self):
message = InboundMessage(
payload="test",
connection_id="conn_id",
receipt=MessageReceipt(),
session_id="session_id",
)
assert not message.processing_complete_event.is_set()
message.dispatch_processing_complete()
assert message.processing_complete_event.is_set()

message = InboundMessage(
payload="test",
connection_id="conn_id",
receipt=MessageReceipt(),
session_id="session_id",
)
assert not message.processing_complete_event.is_set()
task = message.wait_processing_complete()
message.dispatch_processing_complete()
await asyncio.wait_for(task, 1)

0 comments on commit fe08978

Please sign in to comment.