Skip to content

Commit

Permalink
Fix request while protocol context is shutting down (#814)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinHjelmare authored Feb 3, 2024
1 parent 9349dd7 commit 948dbcc
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 27 deletions.
10 changes: 8 additions & 2 deletions pytradfri/api/aiocoap_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,14 @@ def psk(self) -> str | None:
"""Return psk."""
return self._psk

async def _get_protocol(self) -> Context:
async def _get_protocol(self, check_reset_lock: bool = True) -> Context:
"""Get the protocol for the request."""
if check_reset_lock and self._reset_lock.locked():
# If the reset lock is held, it means that the protocol is being reset.
# We are not allowed to make a request with a protocol that is shut down.
# We need to wait for the lock in that case.
async with self._reset_lock:
return await self._get_protocol(check_reset_lock=False)
if self._protocol is None:
self._protocol = asyncio.create_task(Context.create_client_context())
return await self._protocol
Expand All @@ -122,7 +128,7 @@ async def _reset_protocol(self, exc: BaseException | None = None) -> None:
_LOGGER.debug("Resetting protocol")

# Be responsible and clean up.
protocol = await self._get_protocol()
protocol = await self._get_protocol(check_reset_lock=False)
await protocol.shutdown()
self._protocol = None

Expand Down
104 changes: 79 additions & 25 deletions tests/api/test_aiocoap_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
"""Test aiocoap API."""

import asyncio
from collections.abc import Awaitable, Callable
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

from aiocoap import Message
from aiocoap import Context
from aiocoap.credentials import CredentialsMap
from aiocoap.error import Error
import pytest

from pytradfri.api.aiocoap_api import APIFactory
from pytradfri.command import Command
from pytradfri.error import ServerError


class MockCode:
Expand All @@ -31,41 +37,57 @@ def payload(self) -> bytes:
return b'{"one": 1}'


class MockProtocol:
class MockRequest:
"""Mock Protocol."""

async def mock_response(self) -> Any:
"""Return protocol response."""
return MockResponse()
def __init__(self, response: Callable[[], Awaitable]) -> None:
"""Create the request."""
self._response = response

@property
def response(self) -> Any:
"""Return protocol response."""
return self.mock_response()


class MockContext:
"""Mock a context."""

def request(self, message: Message) -> MockProtocol:
"""Request a protocol."""
return MockProtocol()


async def mock_create_context() -> MockContext:
"""Return a context."""
return MockContext()
return self._response()


def process_result(result: Any) -> Any:
"""Process result."""
return result


async def test_request_returns_single(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.fixture(name="response")
def response_fixture() -> AsyncMock:
"""Mock response."""
return AsyncMock()


@pytest.fixture(name="context")
def context_fixture(response) -> MagicMock:
"""Mock context."""
with patch(
"pytradfri.api.aiocoap_api.Context.create_client_context"
) as create_client_context:
context = MagicMock(spec=Context)

async def create_context() -> MagicMock:
"""Reset the context."""
response.return_value = MockResponse()
response.side_effect = None
context.serversite = None
context.request_interfaces = []
context.client_credentials = CredentialsMap()
context.server_credentials = CredentialsMap()
context.request.return_value = MockRequest(response=response)
context.shutdown.side_effect = None
return context

create_client_context.side_effect = create_context
context.create_client_context = create_client_context
yield context


async def test_request_returns_single(context: MagicMock) -> None:
"""Test return single object."""
monkeypatch.setattr("aiocoap.Context.create_client_context", mock_create_context)

api = (await APIFactory.init("127.0.0.1")).request

command: Command[dict[str, int]] = Command("", [""], process_result=process_result)
Expand All @@ -75,10 +97,8 @@ async def test_request_returns_single(monkeypatch: pytest.MonkeyPatch) -> None:
assert response == {"one": 1}


async def test_request_returns_list(monkeypatch: pytest.MonkeyPatch) -> None:
async def test_request_returns_list(context: MagicMock) -> None:
"""Test return of lists."""
monkeypatch.setattr("aiocoap.Context.create_client_context", mock_create_context)

api = (await APIFactory.init("127.0.0.1")).request

command: Command[dict[str, int]] = Command("", [""], process_result=process_result)
Expand All @@ -87,3 +107,37 @@ async def test_request_returns_list(monkeypatch: pytest.MonkeyPatch) -> None:

assert isinstance(response, list)
assert response == [{"one": 1}, {"one": 1}, {"one": 1}]


async def test_context_shutdown_request(
context: MagicMock, response: AsyncMock
) -> None:
"""Test a request while context is shutting down."""
factory = await APIFactory.init("127.0.0.1", psk="test-psk")
shutdown_event = asyncio.Event()

async def mock_shutdown() -> None:
"""Mock shutdown."""
response.side_effect = Exception("Context was shutdown")
await shutdown_event.wait()

context.shutdown.side_effect = mock_shutdown
response.side_effect = Error("Boom!")

request_task_1 = asyncio.create_task(
factory.request(Command("", [""], process_result=process_result))
)
await asyncio.sleep(0)
request_task_2 = asyncio.create_task(
factory.request(Command("", [""], process_result=process_result))
)
await asyncio.sleep(0)

shutdown_event.set()
with pytest.raises(ServerError):
await request_task_1
result = await request_task_2

assert context.shutdown.call_count == 1
assert context.request.call_count == 2
assert result == {"one": 1}

0 comments on commit 948dbcc

Please sign in to comment.