Skip to content

Commit

Permalink
feat: add websocket state subscription (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Jun 22, 2024
1 parent 96eb294 commit d7083ab
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 57 deletions.
54 changes: 36 additions & 18 deletions src/uiprotect/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
to_js_time,
utc_now,
)
from .websocket import Websocket
from .websocket import Websocket, WebsocketState

if sys.version_info[:2] < (3, 13):
from http import cookies
Expand Down Expand Up @@ -163,7 +163,6 @@ class BaseApiClient:
_ws_timeout: int

_is_authenticated: bool = False
_last_ws_status: bool = False
_last_token_cookie: Morsel[str] | None = None
_last_token_cookie_decode: dict[str, Any] | None = None
_session: aiohttp.ClientSession | None = None
Expand Down Expand Up @@ -275,6 +274,7 @@ def _get_websocket(self) -> Websocket:
self._update_bootstrap_soon,
self.get_session,
self._process_ws_message,
self._on_websocket_state_change,
verify=self._verify_ssl,
timeout=self._ws_timeout,
)
Expand Down Expand Up @@ -675,22 +675,6 @@ async def async_disconnect_ws(self) -> None:
await websocket.wait_closed()
self._websocket = None

def check_ws(self) -> bool:
"""Checks current state of Websocket."""
if self._websocket is None:
return False

if not self._websocket.is_connected:
log = _LOGGER.debug
if self._last_ws_status:
log = _LOGGER.warning
log("Websocket connection not active, failing back to polling")
elif not self._last_ws_status:
_LOGGER.info("Websocket re-connected successfully")

self._last_ws_status = self._websocket.is_connected
return self._last_ws_status

def _process_ws_message(self, msg: aiohttp.WSMessage) -> None:
raise NotImplementedError

Expand All @@ -700,6 +684,10 @@ def _get_last_update_id(self) -> str | None:
async def update(self) -> Bootstrap:
raise NotImplementedError

def _on_websocket_state_change(self, state: WebsocketState) -> None:
"""Websocket state changed."""
_LOGGER.debug("Websocket state changed: %s", state)


class ProtectApiClient(BaseApiClient):
"""
Expand Down Expand Up @@ -736,6 +724,7 @@ class ProtectApiClient(BaseApiClient):
_subscribed_models: set[ModelType]
_ignore_stats: bool
_ws_subscriptions: list[Callable[[WSSubscriptionMessage], None]]
_ws_state_subscriptions: list[Callable[[WebsocketState], None]]
_bootstrap: Bootstrap | None = None
_last_update_dt: datetime | None = None
_connection_host: IPv4Address | IPv6Address | str | None = None
Expand Down Expand Up @@ -778,6 +767,7 @@ def __init__(
self._subscribed_models = subscribed_models or set()
self._ignore_stats = ignore_stats
self._ws_subscriptions = []
self._ws_state_subscriptions = []
self.ignore_unadopted = ignore_unadopted
self._update_lock = asyncio.Lock()

Expand Down Expand Up @@ -1140,6 +1130,34 @@ def _unsubscribe_websocket(
if not self._ws_subscriptions:
self._get_websocket().stop()

def subscribe_websocket_state(
self,
ws_callback: Callable[[WebsocketState], None],
) -> Callable[[], None]:
"""
Subscribe to websocket state changes.
Returns a callback that will unsubscribe.
"""
self._ws_state_subscriptions.append(ws_callback)
return partial(self._unsubscribe_websocket_state, ws_callback)

def _unsubscribe_websocket_state(
self,
ws_callback: Callable[[WebsocketState], None],
) -> None:
"""Unsubscribe to websocket state changes."""
self._ws_state_subscriptions.remove(ws_callback)

def _on_websocket_state_change(self, state: WebsocketState) -> None:
"""Websocket state changed."""
super()._on_websocket_state_change(state)
for sub in self._ws_state_subscriptions:
try:
sub(state)
except Exception:
_LOGGER.exception("Exception while running websocket state handler")

async def get_bootstrap(self) -> Bootstrap:
"""
Gets bootstrap object from UFP instance
Expand Down
50 changes: 39 additions & 11 deletions src/uiprotect/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import contextlib
import logging
from collections.abc import Awaitable, Callable, Coroutine
from enum import Enum
from http import HTTPStatus
from typing import Any, Optional

Expand All @@ -28,6 +29,11 @@
_CLOSE_MESSAGE_TYPES = {WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED}


class WebsocketState(Enum):
CONNECTED = True
DISCONNECTED = False


class Websocket:
"""UniFi Protect Websocket manager."""

Expand All @@ -44,6 +50,7 @@ def __init__(
update_bootstrap: UpdateBootstrapCallbackType,
get_session: GetSessionCallbackType,
subscription: Callable[[WSMessage], None],
state_callback: Callable[[WebsocketState], None],
*,
timeout: float = 30.0,
backoff: int = 10,
Expand All @@ -59,10 +66,12 @@ def __init__(
self._update_bootstrap = update_bootstrap
self._subscription = subscription
self._seen_non_close_message = False
self._websocket_state = state_callback
self._current_state: WebsocketState = WebsocketState.DISCONNECTED

@property
def is_connected(self) -> bool:
"""Return if the websocket is connected."""
"""Return if the websocket is connected and has received a valid message."""
return self._ws_connection is not None and not self._ws_connection.closed

async def _websocket_loop(self) -> None:
Expand Down Expand Up @@ -92,11 +101,19 @@ async def _websocket_loop(self) -> None:
except Exception:
_LOGGER.exception("Unexpected error in websocket loop")

self._state_changed(WebsocketState.DISCONNECTED)
if self._running is False:
break
_LOGGER.debug("Reconnecting websocket in %s seconds", backoff)
await asyncio.sleep(self.backoff)

def _state_changed(self, state: WebsocketState) -> None:
"""State changed."""
if self._current_state is state:
return
self._current_state = state
self._websocket_state(state)

async def _websocket_inner_loop(self, url: URL) -> None:
_LOGGER.debug("Connecting WS to %s", url)
await self._attempt_auth(False)
Expand All @@ -119,7 +136,9 @@ async def _websocket_inner_loop(self, url: URL) -> None:
_LOGGER.debug("Websocket closed: %s", msg)
break

self._seen_non_close_message = True
if not self._seen_non_close_message:
self._seen_non_close_message = True
self._state_changed(WebsocketState.CONNECTED)
try:
self._subscription(msg)
except Exception:
Expand Down Expand Up @@ -166,21 +185,30 @@ def stop(self) -> None:
if self._websocket_loop_task:
self._websocket_loop_task.cancel()
self._running = False
self._stop_task = asyncio.create_task(self._stop())
ws_connection = self._ws_connection
websocket_loop_task = self._websocket_loop_task
self._ws_connection = None
self._websocket_loop_task = None
self._stop_task = asyncio.create_task(
self._stop(ws_connection, websocket_loop_task)
)
self._state_changed(WebsocketState.DISCONNECTED)

async def wait_closed(self) -> None:
"""Wait for the websocket to close."""
if self._stop_task:
if self._stop_task and not self._stop_task.done():
with contextlib.suppress(asyncio.CancelledError):
await self._stop_task
self._stop_task = None

async def _stop(self) -> None:
async def _stop(
self,
ws_connection: ClientWebSocketResponse | None,
websocket_loop_task: asyncio.Task[None] | None,
) -> None:
"""Stop the websocket."""
if self._ws_connection:
await self._ws_connection.close()
self._ws_connection = None
if self._websocket_loop_task:
if ws_connection:
await ws_connection.close()
if websocket_loop_task:
with contextlib.suppress(asyncio.CancelledError):
await self._websocket_loop_task
self._websocket_loop_task = None
await websocket_loop_task
53 changes: 25 additions & 28 deletions tests/test_api_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
WSSubscriptionMessage,
)
from uiprotect.utils import print_ws_stat_summary, to_js_time, utc_now
from uiprotect.websocket import WebsocketState

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -515,55 +516,51 @@ async def test_check_ws_connected(
caplog: pytest.LogCaptureFixture,
):
caplog.set_level(logging.DEBUG)

unsub = protect_client_ws.subscribe_websocket(lambda _: None)
await asyncio.sleep(0)
active_ws = protect_client_ws.check_ws()

assert active_ws is True

assert "Websocket re-connected successfully" in caplog.text
while not protect_client_ws._websocket.is_connected:
await asyncio.sleep(0.01)
assert protect_client_ws._websocket.is_connected
unsub()


@pytest.mark.asyncio()
async def test_check_ws_no_ws_initial(
protect_client: ProtectApiClient,
async def test_check_ws_connected_state_callback(
protect_client_ws: ProtectApiClient,
caplog: pytest.LogCaptureFixture,
):
websocket = protect_client_ws._websocket
assert not websocket.is_connected

caplog.set_level(logging.DEBUG)
states: list[bool] = []

await protect_client.async_disconnect_ws()
protect_client._last_ws_status = True
def _on_state(state: bool):
states.append(state)

active_ws = protect_client.check_ws()
unsub_state = protect_client_ws.subscribe_websocket_state(_on_state)
unsub = protect_client_ws.subscribe_websocket(lambda _: None)
while websocket._current_state is not WebsocketState.CONNECTED:
await asyncio.sleep(0.01)

assert active_ws is False
assert states == [WebsocketState.CONNECTED]
await protect_client_ws.async_disconnect_ws()
while websocket._current_state is not WebsocketState.DISCONNECTED:
await asyncio.sleep(0.01)

expected_logs = [
"Disconnecting websocket...",
]
assert expected_logs == [rec.message for rec in caplog.records]
assert states == [WebsocketState.CONNECTED, WebsocketState.DISCONNECTED]
unsub()
unsub_state()


@pytest.mark.asyncio()
async def test_check_ws_no_ws(
async def test_check_ws_no_ws_initial(
protect_client: ProtectApiClient,
caplog: pytest.LogCaptureFixture,
):
caplog.set_level(logging.DEBUG)

await protect_client.async_disconnect_ws()
protect_client._last_ws_status = False

active_ws = protect_client.check_ws()

assert active_ws is False

expected_logs = [
"Disconnecting websocket...",
]
assert expected_logs == [rec.message for rec in caplog.records]
assert not protect_client._websocket


@pytest.mark.skipif(not TEST_CAMERA_EXISTS, reason="Missing testdata")
Expand Down

0 comments on commit d7083ab

Please sign in to comment.