Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Significantly reduce websocket api connection auth phase latency #108564

Merged
merged 2 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions homeassistant/components/websocket_api/auth.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Handle the auth of a connection."""
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any, Final

from aiohttp.web import Request
import voluptuous as vol
from voluptuous.humanize import humanize_error

from homeassistant.auth.models import RefreshToken, User
from homeassistant.components.http.ban import process_success_login, process_wrong_login
from homeassistant.const import __version__
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
Expand Down Expand Up @@ -41,9 +40,9 @@
)


def auth_invalid_message(message: str) -> dict[str, str]:
def auth_invalid_message(message: str) -> bytes:
"""Return an auth_invalid message."""
return {"type": TYPE_AUTH_INVALID, "message": message}
return json_bytes({"type": TYPE_AUTH_INVALID, "message": message})


class AuthPhase:
Expand All @@ -56,13 +55,17 @@ def __init__(
send_message: Callable[[bytes | str | dict[str, Any]], None],
cancel_ws: CALLBACK_TYPE,
request: Request,
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
) -> None:
"""Initialize the authentiated connection."""
"""Initialize the authenticated connection."""
self._hass = hass
# send_message will send a message to the client via the queue.
self._send_message = send_message
self._cancel_ws = cancel_ws
self._logger = logger
self._request = request
# send_bytes_text will directly send a message to the client.
self._send_bytes_text = send_bytes_text

async def async_handle(self, msg: JsonValueType) -> ActiveConnection:
"""Handle authentication."""
Expand All @@ -73,34 +76,33 @@ async def async_handle(self, msg: JsonValueType) -> ActiveConnection:
f"Auth message incorrectly formatted: {humanize_error(msg, err)}"
)
self._logger.warning(error_msg)
self._send_message(auth_invalid_message(error_msg))
await self._send_bytes_text(auth_invalid_message(error_msg))
raise Disconnect from err

if (access_token := valid_msg.get("access_token")) and (
refresh_token := await self._hass.auth.async_validate_access_token(
access_token
)
):
conn = await self._async_finish_auth(refresh_token.user, refresh_token)
conn = ActiveConnection(
self._logger,
self._hass,
self._send_message,
refresh_token.user,
refresh_token,
)
conn.subscriptions[
"auth"
] = self._hass.auth.async_register_revoke_token_callback(
refresh_token.id, self._cancel_ws
)

await self._send_bytes_text(AUTH_OK_MESSAGE)
self._logger.debug("Auth OK")
process_success_login(self._request)
return conn

self._send_message(auth_invalid_message("Invalid access token or password"))
await self._send_bytes_text(
auth_invalid_message("Invalid access token or password")
)
await process_wrong_login(self._request)
raise Disconnect

async def _async_finish_auth(
self, user: User, refresh_token: RefreshToken
) -> ActiveConnection:
"""Create an active connection."""
self._logger.debug("Auth OK")
process_success_login(self._request)
self._send_message(AUTH_OK_MESSAGE)
return ActiveConnection(
self._logger, self._hass, self._send_message, user, refresh_token
)
43 changes: 25 additions & 18 deletions homeassistant/components/websocket_api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import asyncio
from collections import deque
from collections.abc import Callable
from collections.abc import Callable, Coroutine
import datetime as dt
from functools import partial
import logging
Expand Down Expand Up @@ -116,16 +116,14 @@ def description(self) -> str:
return describe_request(request)
return "finished connection"

async def _writer(self) -> None:
async def _writer(
self, send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]]
) -> None:
"""Write outgoing messages."""
# Variables are set locally to avoid lookups in the loop
message_queue = self._message_queue
logger = self._logger
wsock = self._wsock
writer = wsock._writer # pylint: disable=protected-access
if TYPE_CHECKING:
assert writer is not None
send_str = partial(writer.send, binary=False)
loop = self._hass.loop
debug = logger.debug
is_enabled_for = logger.isEnabledFor
Expand All @@ -152,7 +150,7 @@ async def _writer(self) -> None:
):
if debug_enabled:
debug("%s: Sending %s", self.description, message)
await send_str(message)
await send_bytes_text(message)
continue

messages: list[bytes] = [message]
Expand All @@ -166,7 +164,7 @@ async def _writer(self) -> None:
coalesced_messages = b"".join((b"[", b",".join(messages), b"]"))
if debug_enabled:
debug("%s: Sending %s", self.description, coalesced_messages)
await send_str(coalesced_messages)
await send_bytes_text(coalesced_messages)
except asyncio.CancelledError:
debug("%s: Writer cancelled", self.description)
raise
Expand All @@ -186,7 +184,7 @@ def _cancel_peak_checker(self) -> None:

@callback
def _send_message(self, message: str | bytes | dict[str, Any]) -> None:
"""Send a message to the client.
"""Queue sending a message to the client.

Closes connection if the client is not reading the messages.

Expand Down Expand Up @@ -295,21 +293,23 @@ async def async_handle(self) -> web.WebSocketResponse:
EVENT_HOMEASSISTANT_STOP, self._async_handle_hass_stop
)

# As the webserver is now started before the start
# event we do not want to block for websocket responses
self._writer_task = asyncio.create_task(self._writer())
writer = wsock._writer # pylint: disable=protected-access
if TYPE_CHECKING:
assert writer is not None

auth = AuthPhase(logger, hass, self._send_message, self._cancel, request)
send_bytes_text = partial(writer.send, binary=False)
auth = AuthPhase(
logger, hass, self._send_message, self._cancel, request, send_bytes_text
)
connection = None
disconnect_warn = None

try:
self._send_message(AUTH_REQUIRED_MESSAGE)
await send_bytes_text(AUTH_REQUIRED_MESSAGE)

# Auth Phase
try:
async with asyncio.timeout(10):
msg = await wsock.receive()
msg = await wsock.receive(10)
bdraco marked this conversation as resolved.
Show resolved Hide resolved
except asyncio.TimeoutError as err:
disconnect_warn = "Did not receive auth message within 10 seconds"
raise Disconnect from err
Expand All @@ -330,7 +330,13 @@ async def async_handle(self) -> web.WebSocketResponse:
if is_enabled_for(logging_debug):
debug("%s: Received %s", self.description, auth_msg_data)
connection = await auth.async_handle(auth_msg_data)
# As the webserver is now started before the start
# event we do not want to block for websocket responses
#
# We only start the writer queue after the auth phase is completed
# since there is no need to queue messages before the auth phase
self._connection = connection
self._writer_task = asyncio.create_task(self._writer(send_bytes_text))
hass.data[DATA_CONNECTIONS] = hass.data.get(DATA_CONNECTIONS, 0) + 1
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_CONNECTED)

Expand Down Expand Up @@ -370,7 +376,7 @@ async def async_handle(self) -> web.WebSocketResponse:
# added a way to set the limit, but there is no way to actually
# reach the code to set the limit, so we have to set it directly.
#
wsock._writer._limit = 2**20 # type: ignore[union-attr] # pylint: disable=protected-access
writer._limit = 2**20 # pylint: disable=protected-access
async_handle_str = connection.async_handle
async_handle_binary = connection.async_handle_binary

Expand Down Expand Up @@ -441,7 +447,8 @@ async def async_handle(self) -> web.WebSocketResponse:
# so we have another finally block to make sure we close the websocket
# if the writer gets canceled.
try:
await self._writer_task
if self._writer_task:
await self._writer_task
finally:
try:
# Make sure all error messages are written before closing
Expand Down