diff --git a/robot-server/robot_server/runs/dependencies.py b/robot-server/robot_server/runs/dependencies.py index 20b8d087b66d..f66ec9fdf1cd 100644 --- a/robot-server/robot_server/runs/dependencies.py +++ b/robot-server/robot_server/runs/dependencies.py @@ -43,13 +43,12 @@ async def get_run_store( app_state: AppState = Depends(get_app_state), sql_engine: SQLEngine = Depends(get_sql_engine), - runs_publisher: RunsPublisher = Depends(get_runs_publisher), ) -> RunStore: """Get a singleton RunStore to keep track of created runs.""" run_store = _run_store_accessor.get_from(app_state) if run_store is None: - run_store = RunStore(sql_engine=sql_engine, runs_publisher=runs_publisher) + run_store = RunStore(sql_engine=sql_engine) _run_store_accessor.set_on(app_state, run_store) return run_store diff --git a/robot-server/robot_server/runs/run_data_manager.py b/robot-server/robot_server/runs/run_data_manager.py index f0fc28dca371..40201c1f3e46 100644 --- a/robot-server/robot_server/runs/run_data_manager.py +++ b/robot-server/robot_server/runs/run_data_manager.py @@ -180,7 +180,7 @@ async def create( created_at=created_at, protocol_id=protocol.protocol_id if protocol is not None else None, ) - await self._runs_publisher.begin_polling_engine_store( + await self._runs_publisher.initialize( get_current_command=self.get_current_command, get_state_summary=self._get_good_state_summary, run_id=run_id, @@ -271,7 +271,7 @@ async def delete(self, run_id: str) -> None: """ if run_id == self._engine_store.current_run_id: await self._engine_store.clear() - await self._runs_publisher.stop_polling_engine_store() + await self._runs_publisher.clean_up_current_run() self._run_store.remove(run_id=run_id) diff --git a/robot-server/robot_server/runs/run_store.py b/robot-server/robot_server/runs/run_store.py index 6178e180470e..5aa6dbae96bb 100644 --- a/robot-server/robot_server/runs/run_store.py +++ b/robot-server/robot_server/runs/run_store.py @@ -27,7 +27,6 @@ ) from robot_server.persistence.pydantic import json_to_pydantic, pydantic_to_json from robot_server.protocols.protocol_store import ProtocolNotFoundError -from robot_server.service.notifications import RunsPublisher from .action_models import RunAction, RunActionType from .run_models import RunNotFoundError @@ -94,11 +93,9 @@ class RunStore: def __init__( self, sql_engine: sqlalchemy.engine.Engine, - runs_publisher: RunsPublisher, ) -> None: """Initialize a RunStore with sql engine and notification client.""" self._sql_engine = sql_engine - self._runs_publisher = runs_publisher def update_run_state( self, @@ -166,7 +163,6 @@ def update_run_state( action_rows = transaction.execute(select_actions).all() self._clear_caches() - self._runs_publisher.publish_runs_advise_refetch(run_id=run_id) maybe_run_resource = _convert_row_to_run(row=run_row, action_rows=action_rows) if not maybe_run_resource.ok: raise maybe_run_resource.error @@ -192,7 +188,6 @@ def insert_action(self, run_id: str, action: RunAction) -> None: transaction.execute(insert) self._clear_caches() - self._runs_publisher.publish_runs_advise_refetch(run_id=run_id) def insert( self, @@ -235,7 +230,6 @@ def insert( raise ProtocolNotFoundError(protocol_id=run.protocol_id) self._clear_caches() - self._runs_publisher.publish_runs_advise_refetch(run_id=run_id) return run @lru_cache(maxsize=_CACHE_ENTRIES) @@ -467,7 +461,6 @@ def remove(self, run_id: str) -> None: raise RunNotFoundError(run_id) self._clear_caches() - self._runs_publisher.publish_runs_advise_unsubscribe(run_id=run_id) def _run_exists( self, run_id: str, connection: sqlalchemy.engine.Connection diff --git a/robot-server/robot_server/service/notifications/__init__.py b/robot-server/robot_server/service/notifications/__init__.py index 7a71a61298d7..7fd648f32aa6 100644 --- a/robot-server/robot_server/service/notifications/__init__.py +++ b/robot-server/robot_server/service/notifications/__init__.py @@ -14,6 +14,7 @@ get_runs_publisher, ) from .change_notifier import ChangeNotifier +from .topics import Topics __all__ = [ # main export @@ -32,4 +33,5 @@ # for testing "PublisherNotifier", "ChangeNotifier", + "Topics", ] diff --git a/robot-server/robot_server/service/notifications/publishers/runs_publisher.py b/robot-server/robot_server/service/notifications/publishers/runs_publisher.py index 94aed694e8f1..1ed688e6a744 100644 --- a/robot-server/robot_server/service/notifications/publishers/runs_publisher.py +++ b/robot-server/robot_server/service/notifications/publishers/runs_publisher.py @@ -1,7 +1,6 @@ -from fastapi import Depends import asyncio -import logging -from typing import Union, Callable, Optional +from fastapi import Depends +from typing import Callable, Optional from opentrons.protocol_engine import CurrentCommand, StateSummary, EngineStatus @@ -11,173 +10,104 @@ get_app_state, ) from ..notification_client import NotificationClient, get_notification_client +from ..publisher_notifier import PublisherNotifier, get_publisher_notifier from ..topics import Topics -log: logging.Logger = logging.getLogger(__name__) - -POLL_INTERVAL = 1 - - class RunsPublisher: """Publishes protocol runs topics.""" - def __init__(self, client: NotificationClient) -> None: + def __init__( + self, client: NotificationClient, publisher_notifier: PublisherNotifier + ) -> None: """Returns a configured Runs Publisher.""" self._client = client + self._publisher_notifier = publisher_notifier self._run_data_manager_polling = asyncio.Event() - self._previous_current_command: Union[CurrentCommand, None] = None - self._previous_state_summary_status: Union[EngineStatus, None] = None self._poller: Optional[asyncio.Task[None]] = None + # Variables and callbacks related to PE state changes. + self._run_id: Optional[str] = None + self._get_current_command: Optional[ + Callable[[str], Optional[CurrentCommand]] + ] = None + self._get_state_summary: Optional[ + Callable[[str], Optional[StateSummary]] + ] = None + self._previous_current_command: Optional[CurrentCommand] = None + self._previous_state_summary_status: Optional[EngineStatus] = None + + self._publisher_notifier.register_publish_callbacks( + [self._handle_current_command_change, self._handle_engine_status_change] + ) - # TODO(jh, 2023-02-02): Instead of polling, emit current_commands directly from PE. - async def begin_polling_engine_store( + async def initialize( self, - get_current_command: Callable[[str], Optional[CurrentCommand]], - get_state_summary: Callable[[str], Optional[StateSummary]], run_id: str, - ) -> None: - """Continuously poll the engine store for the current_command. - - Args: - get_current_command: Callback to get the currently executing command, if any. - get_state_summary: Callback to get the current run's state summary, if any. - run_id: ID of the current run. - """ - if self._poller is None: - self._poller = asyncio.create_task( - self._poll_engine_store( - get_current_command=get_current_command, - run_id=run_id, - get_state_summary=get_state_summary, - ) - ) - else: - await self.stop_polling_engine_store() - self._poller = asyncio.create_task( - self._poll_engine_store( - get_current_command=get_current_command, - run_id=run_id, - get_state_summary=get_state_summary, - ) - ) - - async def stop_polling_engine_store(self) -> None: - """Stops polling the engine store. Run-related topics will publish as the poller is cancelled.""" - if self._poller is not None: - self._run_data_manager_polling.set() - self._poller.cancel() - - def publish_runs_advise_refetch(self, run_id: str) -> None: - """Publishes the equivalent of GET /runs and GET /runs/:runId. - - Args: - run_id: ID of the current run. - """ - self._client.publish_advise_refetch(topic=Topics.RUNS) - self._client.publish_advise_refetch(topic=f"{Topics.RUNS}/{run_id}") - - def publish_runs_advise_unsubscribe(self, run_id: str) -> None: - """Publishes the equivalent of GET /runs and GET /runs/:runId. - - Args: - run_id: ID of the current run. - """ - self._client.publish_advise_unsubscribe(topic=Topics.RUNS) - self._client.publish_advise_unsubscribe(topic=f"{Topics.RUNS}/{run_id}") - - async def _poll_engine_store( - self, get_current_command: Callable[[str], Optional[CurrentCommand]], get_state_summary: Callable[[str], Optional[StateSummary]], - run_id: str, ) -> None: - """Asynchronously publish new current commands. - - Args: - get_current_command: Retrieves the engine store's current command. - get_state_summary: Retrieves the engine store's state summary. - run_id: ID of the current run. - """ - try: - await self._poll_for_run_id_info( - get_current_command=get_current_command, - get_state_summary=get_state_summary, - run_id=run_id, - ) - except asyncio.CancelledError: - self._clean_up_poller() - await self._publish_runs_advise_unsubscribe_async(run_id=run_id) - await self._client.publish_advise_refetch_async( - topic=Topics.RUNS_CURRENT_COMMAND - ) - except Exception as e: - log.error(f"Error within run data manager poller: {e}") - - async def _poll_for_run_id_info( - self, - get_current_command: Callable[[str], Optional[CurrentCommand]], - get_state_summary: Callable[[str], Optional[StateSummary]], - run_id: str, - ): - """Poll the engine store for a specific run's state while the poll is active. + """Initialize RunsPublisher with necessary information derived from the current run. Args: - get_current_command: Retrieves the engine store's current command. - get_state_summary: Retrieves the engine store's state summary. run_id: ID of the current run. + get_current_command: Callback to get the currently executing command, if any. + get_state_summary: Callback to get the current run's state summary, if any. """ + self._run_id = run_id + self._get_current_command = get_current_command + self._get_state_summary = get_state_summary + self._previous_current_command = None + self._previous_state_summary_status = None - while not self._run_data_manager_polling.is_set(): - current_command = get_current_command(run_id) - current_state_summary = get_state_summary(run_id) - current_state_summary_status = ( - current_state_summary.status if current_state_summary else None - ) - - if self._previous_current_command != current_command: - await self._publish_current_command() - self._previous_current_command = current_command + await self._publish_runs_advise_refetch_async() - if self._previous_state_summary_status != current_state_summary_status: - await self._publish_runs_advise_refetch_async(run_id=run_id) - self._previous_state_summary_status = current_state_summary_status - await asyncio.sleep(POLL_INTERVAL) + async def clean_up_current_run(self) -> None: + """Publish final refetch and unsubscribe flags.""" + await self._publish_runs_advise_refetch_async() + await self._publish_runs_advise_unsubscribe_async() - async def _publish_current_command( - self, - ) -> None: + async def _publish_current_command(self) -> None: """Publishes the equivalent of GET /runs/:runId/commands?cursor=null&pageLength=1.""" await self._client.publish_advise_refetch_async( topic=Topics.RUNS_CURRENT_COMMAND ) - async def _publish_runs_advise_refetch_async(self, run_id: str) -> None: - """Asynchronously publishes the equivalent of GET /runs and GET /runs/:runId via a refetch message. - - Args: - run_id: ID of the current run. - """ + async def _publish_runs_advise_refetch_async(self) -> None: + """Publish a refetch flag for relevant runs topics.""" await self._client.publish_advise_refetch_async(topic=Topics.RUNS) - await self._client.publish_advise_refetch_async(topic=f"{Topics.RUNS}/{run_id}") - - async def _publish_runs_advise_unsubscribe_async(self, run_id: str) -> None: - """Asynchronously publishes the equivalent of GET /runs and GET /runs/:runId via an unsubscribe message. + await self._client.publish_advise_refetch_async( + topic=f"{Topics.RUNS}/{self._run_id}" + ) - Args: - run_id: ID of the current run. - """ - await self._client.publish_advise_unsubscribe_async(topic=Topics.RUNS) + async def _publish_runs_advise_unsubscribe_async(self) -> None: + """Publish an unsubscribe flag for relevant runs topics.""" await self._client.publish_advise_unsubscribe_async( - topic=f"{Topics.RUNS}/{run_id}" + topic=f"{Topics.RUNS}/{self._run_id}" ) - def _clean_up_poller(self) -> None: - """Cleans up the runs data manager poller.""" - self._poller = None - self._run_data_manager_polling.clear() - self._previous_current_command = None - self._previous_state_summary_status = None + async def _handle_current_command_change(self) -> None: + """Publish a refetch flag if the current command has changed.""" + assert self._get_current_command is not None + assert self._run_id is not None + + current_command = self._get_current_command(self._run_id) + if self._previous_current_command != current_command: + await self._publish_current_command() + self._previous_current_command = current_command + + async def _handle_engine_status_change(self) -> None: + """Publish a refetch flag if the engine status has changed.""" + assert self._get_state_summary is not None + assert self._run_id is not None + + current_state_summary = self._get_state_summary(self._run_id) + + if ( + current_state_summary is not None + and self._previous_state_summary_status != current_state_summary.status + ): + await self._publish_runs_advise_refetch_async() + self._previous_state_summary_status = current_state_summary.status _runs_publisher_accessor: AppStateAccessor[RunsPublisher] = AppStateAccessor[ @@ -188,12 +118,15 @@ def _clean_up_poller(self) -> None: async def get_runs_publisher( app_state: AppState = Depends(get_app_state), notification_client: NotificationClient = Depends(get_notification_client), + publisher_notifier: PublisherNotifier = Depends(get_publisher_notifier), ) -> RunsPublisher: """Get a singleton RunsPublisher to publish runs topics.""" runs_publisher = _runs_publisher_accessor.get_from(app_state) if runs_publisher is None: - runs_publisher = RunsPublisher(client=notification_client) + runs_publisher = RunsPublisher( + client=notification_client, publisher_notifier=publisher_notifier + ) _runs_publisher_accessor.set_on(app_state, runs_publisher) return runs_publisher diff --git a/robot-server/tests/protocols/test_protocol_store.py b/robot-server/tests/protocols/test_protocol_store.py index bd6655e4c108..d75212fd2fe3 100644 --- a/robot-server/tests/protocols/test_protocol_store.py +++ b/robot-server/tests/protocols/test_protocol_store.py @@ -50,7 +50,7 @@ def mock_runs_publisher(decoy: Decoy) -> RunsPublisher: @pytest.fixture def run_store(sql_engine: SQLEngine, mock_runs_publisher: RunsPublisher) -> RunStore: """Get a RunStore linked to the same database as the subject ProtocolStore.""" - return RunStore(sql_engine=sql_engine, runs_publisher=mock_runs_publisher) + return RunStore(sql_engine=sql_engine) async def test_insert_and_get_protocol( diff --git a/robot-server/tests/runs/test_run_store.py b/robot-server/tests/runs/test_run_store.py index bb089d4b40a6..31cabbe56bd2 100644 --- a/robot-server/tests/runs/test_run_store.py +++ b/robot-server/tests/runs/test_run_store.py @@ -47,7 +47,6 @@ def subject( """Get a ProtocolStore test subject.""" return RunStore( sql_engine=sql_engine, - runs_publisher=mock_runs_publisher, ) diff --git a/robot-server/tests/service/notifications/publishers/__init__.py b/robot-server/tests/service/notifications/publishers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/robot-server/tests/service/notifications/publishers/test_maintenance_runs_publisher.py b/robot-server/tests/service/notifications/publishers/test_maintenance_runs_publisher.py new file mode 100644 index 000000000000..8a0cb6a18329 --- /dev/null +++ b/robot-server/tests/service/notifications/publishers/test_maintenance_runs_publisher.py @@ -0,0 +1,30 @@ +"""Tests for the maintenance runs publisher.""" +import pytest +from unittest.mock import AsyncMock + +from robot_server.service.notifications import MaintenanceRunsPublisher, Topics + + +@pytest.fixture +def notification_client() -> AsyncMock: + """Mocked notification client.""" + return AsyncMock() + + +@pytest.fixture +def maintenance_runs_publisher( + notification_client: AsyncMock, +) -> MaintenanceRunsPublisher: + """Instantiate MaintenanceRunsPublisher.""" + return MaintenanceRunsPublisher(notification_client) + + +@pytest.mark.asyncio +async def test_publish_current_maintenance_run( + notification_client: AsyncMock, maintenance_runs_publisher: MaintenanceRunsPublisher +) -> None: + """It should publish a notify flag for maintenance runs.""" + await maintenance_runs_publisher.publish_current_maintenance_run() + notification_client.publish_advise_refetch_async.assert_awaited_once_with( + topic=Topics.MAINTENANCE_RUNS_CURRENT_RUN + ) diff --git a/robot-server/tests/service/notifications/publishers/test_runs_publisher.py b/robot-server/tests/service/notifications/publishers/test_runs_publisher.py new file mode 100644 index 000000000000..ff182d2362e4 --- /dev/null +++ b/robot-server/tests/service/notifications/publishers/test_runs_publisher.py @@ -0,0 +1,129 @@ +"""Tests for runs publisher.""" +import pytest +from datetime import datetime +from unittest.mock import MagicMock, AsyncMock + +from robot_server.service.notifications import RunsPublisher, Topics +from opentrons.protocol_engine import CurrentCommand, EngineStatus + + +def mock_curent_command(command_id: str) -> CurrentCommand: + """Create a mock CurrentCommand.""" + return CurrentCommand( + command_id=command_id, + command_key="1", + index=0, + created_at=datetime(year=2021, month=1, day=1), + ) + + +@pytest.fixture +def notification_client() -> AsyncMock: + """Mocked notification client.""" + return AsyncMock() + + +@pytest.fixture +def publisher_notifier() -> AsyncMock: + """Mocked publisher notifier.""" + return AsyncMock() + + +@pytest.fixture +def runs_publisher( + notification_client: AsyncMock, publisher_notifier: AsyncMock +) -> RunsPublisher: + """Instantiate RunsPublisher.""" + return RunsPublisher( + client=notification_client, publisher_notifier=publisher_notifier + ) + + +@pytest.mark.asyncio +async def test_initialize( + runs_publisher: RunsPublisher, notification_client: AsyncMock +) -> None: + """It should initialize the runs_publisher with required parameters and callbacks.""" + run_id = "1234" + get_current_command = AsyncMock() + get_state_summary = AsyncMock() + + await runs_publisher.initialize(run_id, get_current_command, get_state_summary) + + assert runs_publisher._run_id == run_id + assert runs_publisher._get_current_command == get_current_command + assert runs_publisher._get_state_summary == get_state_summary + assert runs_publisher._previous_current_command is None + assert runs_publisher._previous_state_summary_status is None + + notification_client.publish_advise_refetch_async.assert_any_await(topic=Topics.RUNS) + notification_client.publish_advise_refetch_async.assert_any_await( + topic=f"{Topics.RUNS}/1234" + ) + + +@pytest.mark.asyncio +async def test_clean_up_current_run( + runs_publisher: RunsPublisher, notification_client: AsyncMock +) -> None: + """It should publish to appropriate topics at the end of a run.""" + runs_publisher._run_id = "1234" + + await runs_publisher.clean_up_current_run() + + notification_client.publish_advise_refetch_async.assert_any_await(topic=Topics.RUNS) + notification_client.publish_advise_refetch_async.assert_any_await( + topic=f"{Topics.RUNS}/1234" + ) + notification_client.publish_advise_unsubscribe_async.assert_any_await( + topic=f"{Topics.RUNS}/1234" + ) + + +@pytest.mark.asyncio +async def test_handle_current_command_change( + runs_publisher: RunsPublisher, notification_client: AsyncMock +) -> None: + """It should handle command changes appropriately.""" + runs_publisher._run_id = "1234" + runs_publisher._get_current_command = lambda _: mock_curent_command("command1") + runs_publisher._previous_current_command = mock_curent_command("command1") + + await runs_publisher._handle_current_command_change() + + assert not notification_client.publish_advise_refetch_async.called + + runs_publisher._get_current_command = lambda _: mock_curent_command("command2") + + await runs_publisher._handle_current_command_change() + + notification_client.publish_advise_refetch_async.assert_any_await( + topic=Topics.RUNS_CURRENT_COMMAND + ) + + +@pytest.mark.asyncio +async def test_handle_engine_status_change( + runs_publisher: RunsPublisher, notification_client: AsyncMock +) -> None: + """It should handle engine status changes appropriately.""" + runs_publisher._run_id = "1234" + runs_publisher._get_state_summary = MagicMock( + return_value=MagicMock(status=EngineStatus.IDLE) + ) + runs_publisher._previous_state_summary_status = EngineStatus.IDLE + + await runs_publisher._handle_engine_status_change() + + assert not notification_client.publish_advise_refetch_async.called + + runs_publisher._get_state_summary.return_value = MagicMock( + status=EngineStatus.RUNNING + ) + + await runs_publisher._handle_engine_status_change() + + notification_client.publish_advise_refetch_async.assert_any_await(topic=Topics.RUNS) + notification_client.publish_advise_refetch_async.assert_any_await( + topic=f"{Topics.RUNS}/1234" + )