diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 6bb0c154540081..ab228e32a52bff 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -32,6 +32,11 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import storage import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.http import ( + KEY_AUTHENTICATED, # noqa: F401 + HomeAssistantView, + current_request, +) from homeassistant.helpers.network import NoURLAvailableError, get_url from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass @@ -41,20 +46,14 @@ from .auth import async_setup_auth from .ban import setup_bans -from .const import ( # noqa: F401 - KEY_AUTHENTICATED, - KEY_HASS, - KEY_HASS_REFRESH_TOKEN_ID, - KEY_HASS_USER, -) +from .const import KEY_HASS, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER # noqa: F401 from .cors import setup_cors from .decorators import require_admin # noqa: F401 from .forwarded import async_setup_forwarded from .headers import setup_headers -from .request_context import current_request, setup_request_context +from .request_context import setup_request_context from .security_filter import setup_security_filter from .static import CACHE_HEADERS, CachingStaticResource -from .view import HomeAssistantView from .web_runner import HomeAssistantTCPSite DOMAIN: Final = "http" diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py index 99d38bf582edc2..640d899924e3d6 100644 --- a/homeassistant/components/http/auth.py +++ b/homeassistant/components/http/auth.py @@ -20,13 +20,13 @@ from homeassistant.auth.models import User from homeassistant.components import websocket_api from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.http import current_request from homeassistant.helpers.json import json_bytes from homeassistant.helpers.network import is_cloud_connection from homeassistant.helpers.storage import Store from homeassistant.util.network import is_local from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER -from .request_context import current_request _LOGGER = logging.getLogger(__name__) diff --git a/homeassistant/components/http/ban.py b/homeassistant/components/http/ban.py index 62569495ba70f7..0b720b078b989a 100644 --- a/homeassistant/components/http/ban.py +++ b/homeassistant/components/http/ban.py @@ -15,7 +15,6 @@ from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized import voluptuous as vol -from homeassistant.components import persistent_notification from homeassistant.config import load_yaml_config_file from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError @@ -128,6 +127,10 @@ async def process_wrong_login(request: Request) -> None: _LOGGER.warning(log_msg) + # Circular import with websocket_api + # pylint: disable=import-outside-toplevel + from homeassistant.components import persistent_notification + persistent_notification.async_create( hass, notification_msg, "Login attempt failed", NOTIFICATION_ID_LOGIN ) diff --git a/homeassistant/components/http/const.py b/homeassistant/components/http/const.py index df27122b64a9e5..090e5234aebd22 100644 --- a/homeassistant/components/http/const.py +++ b/homeassistant/components/http/const.py @@ -1,7 +1,8 @@ """HTTP specific constants.""" from typing import Final -KEY_AUTHENTICATED: Final = "ha_authenticated" +from homeassistant.helpers.http import KEY_AUTHENTICATED # noqa: F401 + KEY_HASS: Final = "hass" KEY_HASS_USER: Final = "hass_user" KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id" diff --git a/homeassistant/components/http/request_context.py b/homeassistant/components/http/request_context.py index 6e036b9cdc8dbf..b516b63dc5c928 100644 --- a/homeassistant/components/http/request_context.py +++ b/homeassistant/components/http/request_context.py @@ -7,10 +7,7 @@ from aiohttp.web import Application, Request, StreamResponse, middleware from homeassistant.core import callback - -current_request: ContextVar[Request | None] = ContextVar( - "current_request", default=None -) +from homeassistant.helpers.http import current_request # noqa: F401 @callback diff --git a/homeassistant/components/http/view.py b/homeassistant/components/http/view.py index 1be3d761a3b215..ce02879dbb37c9 100644 --- a/homeassistant/components/http/view.py +++ b/homeassistant/components/http/view.py @@ -1,180 +1,7 @@ """Support for views.""" from __future__ import annotations -import asyncio -from collections.abc import Awaitable, Callable -from http import HTTPStatus -import logging -from typing import Any - -from aiohttp import web -from aiohttp.typedefs import LooseHeaders -from aiohttp.web_exceptions import ( - HTTPBadRequest, - HTTPInternalServerError, - HTTPUnauthorized, -) -from aiohttp.web_urldispatcher import AbstractRoute -import voluptuous as vol - -from homeassistant import exceptions -from homeassistant.const import CONTENT_TYPE_JSON -from homeassistant.core import Context, HomeAssistant, is_callback -from homeassistant.helpers.json import ( - find_paths_unserializable_data, - json_bytes, - json_dumps, +from homeassistant.helpers.http import ( # noqa: F401 + HomeAssistantView, + request_handler_factory, ) -from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS, format_unserializable_data - -from .const import KEY_AUTHENTICATED - -_LOGGER = logging.getLogger(__name__) - - -class HomeAssistantView: - """Base view for all views.""" - - url: str | None = None - extra_urls: list[str] = [] - # Views inheriting from this class can override this - requires_auth = True - cors_allowed = False - - @staticmethod - def context(request: web.Request) -> Context: - """Generate a context from a request.""" - if (user := request.get("hass_user")) is None: - return Context() - - return Context(user_id=user.id) - - @staticmethod - def json( - result: Any, - status_code: HTTPStatus | int = HTTPStatus.OK, - headers: LooseHeaders | None = None, - ) -> web.Response: - """Return a JSON response.""" - try: - msg = json_bytes(result) - except JSON_ENCODE_EXCEPTIONS as err: - _LOGGER.error( - "Unable to serialize to JSON. Bad data found at %s", - format_unserializable_data( - find_paths_unserializable_data(result, dump=json_dumps) - ), - ) - raise HTTPInternalServerError from err - response = web.Response( - body=msg, - content_type=CONTENT_TYPE_JSON, - status=int(status_code), - headers=headers, - zlib_executor_size=32768, - ) - response.enable_compression() - return response - - def json_message( - self, - message: str, - status_code: HTTPStatus | int = HTTPStatus.OK, - message_code: str | None = None, - headers: LooseHeaders | None = None, - ) -> web.Response: - """Return a JSON message response.""" - data = {"message": message} - if message_code is not None: - data["code"] = message_code - return self.json(data, status_code, headers=headers) - - def register( - self, hass: HomeAssistant, app: web.Application, router: web.UrlDispatcher - ) -> None: - """Register the view with a router.""" - assert self.url is not None, "No url set for view" - urls = [self.url] + self.extra_urls - routes: list[AbstractRoute] = [] - - for method in ("get", "post", "delete", "put", "patch", "head", "options"): - if not (handler := getattr(self, method, None)): - continue - - handler = request_handler_factory(hass, self, handler) - - for url in urls: - routes.append(router.add_route(method, url, handler)) - - # Use `get` because CORS middleware is not be loaded in emulated_hue - if self.cors_allowed: - allow_cors = app.get("allow_all_cors") - else: - allow_cors = app.get("allow_configured_cors") - - if allow_cors: - for route in routes: - allow_cors(route) - - -def request_handler_factory( - hass: HomeAssistant, view: HomeAssistantView, handler: Callable -) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: - """Wrap the handler classes.""" - is_coroutinefunction = asyncio.iscoroutinefunction(handler) - assert is_coroutinefunction or is_callback( - handler - ), "Handler should be a coroutine or a callback." - - async def handle(request: web.Request) -> web.StreamResponse: - """Handle incoming request.""" - if hass.is_stopping: - return web.Response(status=HTTPStatus.SERVICE_UNAVAILABLE) - - authenticated = request.get(KEY_AUTHENTICATED, False) - - if view.requires_auth and not authenticated: - raise HTTPUnauthorized() - - if _LOGGER.isEnabledFor(logging.DEBUG): - _LOGGER.debug( - "Serving %s to %s (auth: %s)", - request.path, - request.remote, - authenticated, - ) - - try: - if is_coroutinefunction: - result = await handler(request, **request.match_info) - else: - result = handler(request, **request.match_info) - except vol.Invalid as err: - raise HTTPBadRequest() from err - except exceptions.ServiceNotFound as err: - raise HTTPInternalServerError() from err - except exceptions.Unauthorized as err: - raise HTTPUnauthorized() from err - - if isinstance(result, web.StreamResponse): - # The method handler returned a ready-made Response, how nice of it - return result - - status_code = HTTPStatus.OK - if isinstance(result, tuple): - result, status_code = result - - if isinstance(result, bytes): - return web.Response(body=result, status=status_code) - - if isinstance(result, str): - return web.Response(text=result, status=status_code) - - if result is None: - return web.Response(body=b"", status=status_code) - - raise TypeError( - f"Result should be None, string, bytes or StreamResponse. Got: {result}" - ) - - return handle diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 280ff41c56e3bd..aa7bcefadae456 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -9,9 +9,9 @@ import voluptuous as vol from homeassistant.auth.models import RefreshToken, User -from homeassistant.components.http import current_request from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, Unauthorized +from homeassistant.helpers.http import current_request from homeassistant.util.json import JsonValueType from . import const, messages diff --git a/homeassistant/helpers/http.py b/homeassistant/helpers/http.py new file mode 100644 index 00000000000000..63ff173a3a0266 --- /dev/null +++ b/homeassistant/helpers/http.py @@ -0,0 +1,184 @@ +"""Helper to track the current http request.""" +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from contextvars import ContextVar +from http import HTTPStatus +import logging +from typing import Any, Final + +from aiohttp import web +from aiohttp.typedefs import LooseHeaders +from aiohttp.web import Request +from aiohttp.web_exceptions import ( + HTTPBadRequest, + HTTPInternalServerError, + HTTPUnauthorized, +) +from aiohttp.web_urldispatcher import AbstractRoute +import voluptuous as vol + +from homeassistant import exceptions +from homeassistant.const import CONTENT_TYPE_JSON +from homeassistant.core import Context, HomeAssistant, is_callback +from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS, format_unserializable_data + +from .json import find_paths_unserializable_data, json_bytes, json_dumps + +_LOGGER = logging.getLogger(__name__) + + +KEY_AUTHENTICATED: Final = "ha_authenticated" + +current_request: ContextVar[Request | None] = ContextVar( + "current_request", default=None +) + + +def request_handler_factory( + hass: HomeAssistant, view: HomeAssistantView, handler: Callable +) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: + """Wrap the handler classes.""" + is_coroutinefunction = asyncio.iscoroutinefunction(handler) + assert is_coroutinefunction or is_callback( + handler + ), "Handler should be a coroutine or a callback." + + async def handle(request: web.Request) -> web.StreamResponse: + """Handle incoming request.""" + if hass.is_stopping: + return web.Response(status=HTTPStatus.SERVICE_UNAVAILABLE) + + authenticated = request.get(KEY_AUTHENTICATED, False) + + if view.requires_auth and not authenticated: + raise HTTPUnauthorized() + + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug( + "Serving %s to %s (auth: %s)", + request.path, + request.remote, + authenticated, + ) + + try: + if is_coroutinefunction: + result = await handler(request, **request.match_info) + else: + result = handler(request, **request.match_info) + except vol.Invalid as err: + raise HTTPBadRequest() from err + except exceptions.ServiceNotFound as err: + raise HTTPInternalServerError() from err + except exceptions.Unauthorized as err: + raise HTTPUnauthorized() from err + + if isinstance(result, web.StreamResponse): + # The method handler returned a ready-made Response, how nice of it + return result + + status_code = HTTPStatus.OK + if isinstance(result, tuple): + result, status_code = result + + if isinstance(result, bytes): + return web.Response(body=result, status=status_code) + + if isinstance(result, str): + return web.Response(text=result, status=status_code) + + if result is None: + return web.Response(body=b"", status=status_code) + + raise TypeError( + f"Result should be None, string, bytes or StreamResponse. Got: {result}" + ) + + return handle + + +class HomeAssistantView: + """Base view for all views.""" + + url: str | None = None + extra_urls: list[str] = [] + # Views inheriting from this class can override this + requires_auth = True + cors_allowed = False + + @staticmethod + def context(request: web.Request) -> Context: + """Generate a context from a request.""" + if (user := request.get("hass_user")) is None: + return Context() + + return Context(user_id=user.id) + + @staticmethod + def json( + result: Any, + status_code: HTTPStatus | int = HTTPStatus.OK, + headers: LooseHeaders | None = None, + ) -> web.Response: + """Return a JSON response.""" + try: + msg = json_bytes(result) + except JSON_ENCODE_EXCEPTIONS as err: + _LOGGER.error( + "Unable to serialize to JSON. Bad data found at %s", + format_unserializable_data( + find_paths_unserializable_data(result, dump=json_dumps) + ), + ) + raise HTTPInternalServerError from err + response = web.Response( + body=msg, + content_type=CONTENT_TYPE_JSON, + status=int(status_code), + headers=headers, + zlib_executor_size=32768, + ) + response.enable_compression() + return response + + def json_message( + self, + message: str, + status_code: HTTPStatus | int = HTTPStatus.OK, + message_code: str | None = None, + headers: LooseHeaders | None = None, + ) -> web.Response: + """Return a JSON message response.""" + data = {"message": message} + if message_code is not None: + data["code"] = message_code + return self.json(data, status_code, headers=headers) + + def register( + self, hass: HomeAssistant, app: web.Application, router: web.UrlDispatcher + ) -> None: + """Register the view with a router.""" + assert self.url is not None, "No url set for view" + urls = [self.url] + self.extra_urls + routes: list[AbstractRoute] = [] + + for method in ("get", "post", "delete", "put", "patch", "head", "options"): + if not (handler := getattr(self, method, None)): + continue + + handler = request_handler_factory(hass, self, handler) + + for url in urls: + routes.append(router.add_route(method, url, handler)) + + # Use `get` because CORS middleware is not be loaded in emulated_hue + if self.cors_allowed: + allow_cors = app.get("allow_all_cors") + else: + allow_cors = app.get("allow_configured_cors") + + if allow_cors: + for route in routes: + allow_cors(route) diff --git a/tests/test_circular_imports.py b/tests/test_circular_imports.py new file mode 100644 index 00000000000000..1c5157b74e1e18 --- /dev/null +++ b/tests/test_circular_imports.py @@ -0,0 +1,39 @@ +"""Test to check for circular imports in core components.""" +import asyncio +import sys + +import pytest + +from homeassistant.bootstrap import ( + CORE_INTEGRATIONS, + DEBUGGER_INTEGRATIONS, + DEFAULT_INTEGRATIONS, + FRONTEND_INTEGRATIONS, + LOGGING_INTEGRATIONS, + RECORDER_INTEGRATIONS, + STAGE_1_INTEGRATIONS, +) + + +@pytest.mark.timeout(30) # cloud can take > 9s +@pytest.mark.parametrize( + "component", + sorted( + { + *DEBUGGER_INTEGRATIONS, + *CORE_INTEGRATIONS, + *LOGGING_INTEGRATIONS, + *FRONTEND_INTEGRATIONS, + *RECORDER_INTEGRATIONS, + *STAGE_1_INTEGRATIONS, + *DEFAULT_INTEGRATIONS, + } + ), +) +async def test_circular_imports(component: str) -> None: + """Check that components can be imported without circular imports.""" + process = await asyncio.create_subprocess_exec( + sys.executable, "-c", f"import homeassistant.components.{component}" + ) + await process.communicate() + assert process.returncode == 0