diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index 27eed49e5ca503..33aeb283f5a05c 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -626,7 +626,7 @@ def websocket_delete_all_refresh_tokens( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] ) -> None: """Handle delete all refresh tokens request.""" - current_refresh_token: RefreshToken + current_refresh_token: RefreshToken | None = None remove_failed = False token_type = msg.get("token_type") delete_current_token = msg.get("delete_current_token") @@ -654,7 +654,7 @@ def websocket_delete_all_refresh_tokens( else: connection.send_result(msg["id"], {}) - async def _delete_current_token_soon() -> None: + async def _delete_current_token_soon(current_refresh_token: RefreshToken) -> None: """Delete the current token after a delay. We do not want to delete the current token immediately as it will @@ -675,13 +675,15 @@ async def _delete_current_token_soon() -> None: # the token right away. hass.auth.async_remove_refresh_token(current_refresh_token) - if delete_current_token and ( - not limit_token_types or current_refresh_token.token_type == token_type + if ( + delete_current_token + and current_refresh_token + and (not limit_token_types or current_refresh_token.token_type == token_type) ): # Deleting the token will close the connection so we need # to do it with a delay in a tracked task to ensure it still # happens if Home Assistant is shutting down. - hass.async_create_task(_delete_current_token_soon()) + hass.async_create_task(_delete_current_token_soon(current_refresh_token)) @websocket_api.websocket_command( diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 75971b1ed1d5f1..a4db676ffe38b4 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -10,6 +10,7 @@ from ipaddress import IPv4Network, IPv6Network, ip_network import logging import os +from pathlib import Path import socket import ssl from tempfile import NamedTemporaryFile @@ -33,6 +34,7 @@ from homeassistant.const import ( EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, + HASSIO_USER_NAME, SERVER_PORT, ) from homeassistant.core import Event, HomeAssistant, callback @@ -69,7 +71,7 @@ from .request_context import setup_request_context from .security_filter import setup_security_filter from .static import CACHE_HEADERS, CachingStaticResource -from .web_runner import HomeAssistantTCPSite +from .web_runner import HomeAssistantTCPSite, HomeAssistantUnixSite CONF_SERVER_HOST: Final = "server_host" CONF_SERVER_PORT: Final = "server_port" @@ -235,6 +237,17 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: source_ip_task = create_eager_task(async_get_source_ip(hass)) + supervisor_unix_socket_path: Path | None = None + if socket_env := os.environ.get("SUPERVISOR_CORE_API_SOCKET"): + socket_path = Path(socket_env) + if socket_path.is_absolute(): + supervisor_unix_socket_path = socket_path + else: + _LOGGER.error( + "Invalid Supervisor Unix socket path %s: path must be absolute", + socket_env, + ) + server = HomeAssistantHTTP( hass, server_host=server_host, @@ -244,6 +257,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ssl_key=ssl_key, trusted_proxies=trusted_proxies, ssl_profile=ssl_profile, + supervisor_unix_socket_path=supervisor_unix_socket_path, ) await server.async_initialize( cors_origins=cors_origins, @@ -267,6 +281,21 @@ async def start_server(*_: Any) -> None: async_when_setup_or_start(hass, "frontend", start_server) + if server.supervisor_unix_socket_path is not None: + + async def start_supervisor_unix_socket(*_: Any) -> None: + """Start the Unix socket after the Supervisor user is available.""" + if any( + user + for user in await hass.auth.async_get_users() + if user.system_generated and user.name == HASSIO_USER_NAME + ): + await server.async_start_supervisor_unix_socket() + else: + _LOGGER.error("Supervisor user not found; not starting Unix socket") + + async_when_setup_or_start(hass, "hassio", start_supervisor_unix_socket) + hass.http = server local_ip = await source_ip_task @@ -366,6 +395,7 @@ def __init__( server_port: int, trusted_proxies: list[IPv4Network | IPv6Network], ssl_profile: str, + supervisor_unix_socket_path: Path | None = None, ) -> None: """Initialize the HTTP Home Assistant server.""" self.app = HomeAssistantApplication( @@ -384,8 +414,10 @@ def __init__( self.server_port = server_port self.trusted_proxies = trusted_proxies self.ssl_profile = ssl_profile + self.supervisor_unix_socket_path = supervisor_unix_socket_path self.runner: web.AppRunner | None = None self.site: HomeAssistantTCPSite | None = None + self.supervisor_site: HomeAssistantUnixSite | None = None self.context: ssl.SSLContext | None = None async def async_initialize( @@ -610,6 +642,33 @@ def _create_emergency_ssl_context(self) -> ssl.SSLContext: context.load_cert_chain(cert_pem.name, key_pem.name) return context + async def async_start_supervisor_unix_socket(self) -> None: + """Start listening on the Unix socket. + + This is called separately from start() to delay serving the Unix + socket until the Supervisor user exists (created by the hassio + integration). Without this delay, Supervisor could connect before + its user is available and receive 401 responses it won't retry. + """ + if self.supervisor_unix_socket_path is None or self.runner is None: + return + self.supervisor_site = HomeAssistantUnixSite( + self.runner, self.supervisor_unix_socket_path + ) + try: + await self.supervisor_site.start() + except OSError as error: + _LOGGER.error( + "Failed to create HTTP server on unix socket %s: %s", + self.supervisor_unix_socket_path, + error, + ) + self.supervisor_site = None + else: + _LOGGER.info( + "Now listening on unix socket %s", self.supervisor_unix_socket_path + ) + async def start(self) -> None: """Start the aiohttp server.""" # Aiohttp freezes apps after start so that no changes can be made. @@ -637,6 +696,19 @@ async def start(self) -> None: async def stop(self) -> None: """Stop the aiohttp server.""" + if self.supervisor_site is not None: + await self.supervisor_site.stop() + if self.supervisor_unix_socket_path is not None: + try: + await self.hass.async_add_executor_job( + self.supervisor_unix_socket_path.unlink, True + ) + except OSError as err: + _LOGGER.warning( + "Could not remove Supervisor unix socket %s: %s", + self.supervisor_unix_socket_path, + err, + ) if self.site is not None: await self.site.stop() if self.runner is not None: diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py index 227ee074439e39..50b3812dd7dd54 100644 --- a/homeassistant/components/http/auth.py +++ b/homeassistant/components/http/auth.py @@ -11,7 +11,13 @@ from typing import Any, Final from aiohttp import hdrs -from aiohttp.web import Application, Request, StreamResponse, middleware +from aiohttp.web import ( + Application, + HTTPInternalServerError, + Request, + StreamResponse, + middleware, +) import jwt from jwt import api_jws from yarl import URL @@ -20,6 +26,7 @@ from homeassistant.auth.const import GROUP_ID_READ_ONLY from homeassistant.auth.models import User from homeassistant.components import websocket_api +from homeassistant.const import HASSIO_USER_NAME from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.http import current_request from homeassistant.helpers.json import json_bytes @@ -27,7 +34,12 @@ from homeassistant.helpers.storage import Store from homeassistant.util.network import is_local -from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER +from .const import ( + KEY_AUTHENTICATED, + KEY_HASS_REFRESH_TOKEN_ID, + KEY_HASS_USER, + is_supervisor_unix_socket_request, +) _LOGGER = logging.getLogger(__name__) @@ -57,7 +69,9 @@ def async_sign_path( if refresh_token_id is None: if use_content_user: refresh_token_id = hass.data[STORAGE_KEY] - elif connection := websocket_api.current_connection.get(): + elif ( + connection := websocket_api.current_connection.get() + ) and connection.refresh_token_id: refresh_token_id = connection.refresh_token_id elif ( request := current_request.get() @@ -117,7 +131,7 @@ def async_user_not_allowed_do_auth( return "User cannot authenticate remotely" -async def async_setup_auth( +async def async_setup_auth( # noqa: C901 hass: HomeAssistant, app: Application, ) -> None: @@ -207,6 +221,41 @@ def async_validate_signed_request(request: Request) -> bool: request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id return True + supervisor_user_id: str | None = None + + async def async_authenticate_supervisor_unix_socket(request: Request) -> bool: + """Authenticate a request from a Unix socket as the Supervisor user. + + The Unix Socket is dedicated and only available to Supervisor. To + avoid the extra overhead and round trips for the authentication and + refresh tokens, we directly authenticate requests from the socket as + the Supervisor user. + """ + nonlocal supervisor_user_id + + # Fast path: use cached user ID + if supervisor_user_id is not None: + if user := await hass.auth.async_get_user(supervisor_user_id): + request[KEY_HASS_USER] = user + return True + supervisor_user_id = None + + # Slow path: find the Supervisor user by name + for user in await hass.auth.async_get_users(): + if user.system_generated and user.name == HASSIO_USER_NAME: + supervisor_user_id = user.id + # Not setting KEY_HASS_REFRESH_TOKEN_ID since Supervisor user + # doesn't use refresh tokens. + request[KEY_HASS_USER] = user + return True + + # The Unix socket should not be serving before the hassio integration + # has created the Supervisor user. If we get here, something is wrong. + _LOGGER.error( + "Supervisor user not found; cannot authenticate Unix socket request" + ) + raise HTTPInternalServerError + @middleware async def auth_middleware( request: Request, handler: Callable[[Request], Awaitable[StreamResponse]] @@ -214,7 +263,11 @@ async def auth_middleware( """Authenticate as middleware.""" authenticated = False - if hdrs.AUTHORIZATION in request.headers and async_validate_auth_header( + if is_supervisor_unix_socket_request(request): + authenticated = await async_authenticate_supervisor_unix_socket(request) + auth_type = "supervisor unix socket" + + elif hdrs.AUTHORIZATION in request.headers and async_validate_auth_header( request ): authenticated = True @@ -233,7 +286,7 @@ async def auth_middleware( if authenticated and _LOGGER.isEnabledFor(logging.DEBUG): _LOGGER.debug( "Authenticated %s for %s using %s", - request.remote, + request.remote or "unknown remote", request.path, auth_type, ) diff --git a/homeassistant/components/http/ban.py b/homeassistant/components/http/ban.py index e9ebdb6bfc7e14..e2ec1ad95a3139 100644 --- a/homeassistant/components/http/ban.py +++ b/homeassistant/components/http/ban.py @@ -30,7 +30,7 @@ from homeassistant.helpers.hassio import get_supervisor_ip, is_hassio from homeassistant.util import dt as dt_util, yaml as yaml_util -from .const import KEY_HASS +from .const import KEY_HASS, is_supervisor_unix_socket_request from .view import HomeAssistantView _LOGGER: Final = logging.getLogger(__name__) @@ -72,6 +72,10 @@ async def ban_middleware( request: Request, handler: Callable[[Request], Awaitable[StreamResponse]] ) -> StreamResponse: """IP Ban middleware.""" + # Unix socket connections are trusted, skip ban checks + if is_supervisor_unix_socket_request(request): + return await handler(request) + if (ban_manager := request.app.get(KEY_BAN_MANAGER)) is None: _LOGGER.error("IP Ban middleware loaded but banned IPs not loaded") return await handler(request) diff --git a/homeassistant/components/http/const.py b/homeassistant/components/http/const.py index 1a5d7a603d75f3..c89751a62affcc 100644 --- a/homeassistant/components/http/const.py +++ b/homeassistant/components/http/const.py @@ -2,9 +2,23 @@ from typing import Final +from aiohttp.web import Request + from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401 DOMAIN: Final = "http" KEY_HASS_USER: Final = "hass_user" KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id" + + +def is_supervisor_unix_socket_request(request: Request) -> bool: + """Check if request arrived over the Supervisor Unix socket.""" + if (transport := request.transport) is None: + return False + if (http := request.app[KEY_HASS].http) is None or ( + supervisor_path := http.supervisor_unix_socket_path + ) is None: + return False + sockname: str | None = transport.get_extra_info("sockname") + return sockname == str(supervisor_path) diff --git a/homeassistant/components/http/web_runner.py b/homeassistant/components/http/web_runner.py index f633433c9e4d54..a28b69ba9d3a12 100644 --- a/homeassistant/components/http/web_runner.py +++ b/homeassistant/components/http/web_runner.py @@ -3,6 +3,8 @@ from __future__ import annotations import asyncio +from pathlib import Path +import socket from ssl import SSLContext from aiohttp import web @@ -68,3 +70,62 @@ async def start(self) -> None: reuse_address=self._reuse_address, reuse_port=self._reuse_port, ) + + +class HomeAssistantUnixSite(web.BaseSite): + """HomeAssistant specific aiohttp UnixSite. + + Listens on a Unix socket for local inter-process communication, + used for Supervisor to Core communication. + """ + + __slots__ = ("_path",) + + def __init__( + self, + runner: web.BaseRunner, + path: Path, + *, + backlog: int = 128, + ) -> None: + """Initialize HomeAssistantUnixSite.""" + super().__init__( + runner, + backlog=backlog, + ) + self._path = path + + @property + def name(self) -> str: + """Return server URL.""" + return f"http://unix:{self._path}:" + + def _create_unix_socket(self) -> socket.socket: + """Create and bind a Unix domain socket. + + Performs blocking filesystem I/O (mkdir, unlink, chmod) and is + intended to be run in an executor. Permissions are set after bind + but before the socket is handed to the event loop, so no + connections can arrive on an unrestricted socket. + """ + self._path.parent.mkdir(parents=True, exist_ok=True) + self._path.unlink(missing_ok=True) + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.bind(str(self._path)) + self._path.chmod(0o600) + except OSError: + sock.close() + raise + return sock + + async def start(self) -> None: + """Start server.""" + await super().start() + loop = asyncio.get_running_loop() + sock = await loop.run_in_executor(None, self._create_unix_socket) + server = self._runner.server + assert server is not None + self._server = await loop.create_unix_server( + server, sock=sock, backlog=self._backlog + ) diff --git a/homeassistant/components/onboarding/views.py b/homeassistant/components/onboarding/views.py index 4e2f6a18e0d602..b78b789d5e2171 100644 --- a/homeassistant/components/onboarding/views.py +++ b/homeassistant/components/onboarding/views.py @@ -283,7 +283,10 @@ class IntegrationOnboardingView(_BaseOnboardingStepView): async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response: """Handle token creation.""" hass = request.app[KEY_HASS] - refresh_token_id = request[KEY_HASS_REFRESH_TOKEN_ID] + if not (refresh_token_id := request.get(KEY_HASS_REFRESH_TOKEN_ID)): + return self.json_message( + "Refresh token not available", HTTPStatus.FORBIDDEN + ) async with self._lock: if self._async_is_done(): diff --git a/homeassistant/components/websocket_api/auth.py b/homeassistant/components/websocket_api/auth.py index a15f76632c1ce1..b0e319bbce5ad4 100644 --- a/homeassistant/components/websocket_api/auth.py +++ b/homeassistant/components/websocket_api/auth.py @@ -10,6 +10,7 @@ from voluptuous.humanize import humanize_error from homeassistant.components.http.ban import process_success_login, process_wrong_login +from homeassistant.components.http.const import KEY_HASS_USER from homeassistant.const import __version__ from homeassistant.core import CALLBACK_TYPE, HomeAssistant from homeassistant.helpers.json import json_bytes @@ -68,6 +69,19 @@ def __init__( # send_bytes_text will directly send a message to the client. self._send_bytes_text = send_bytes_text + async def async_handle_supervisor_unix_socket(self) -> ActiveConnection: + """Handle a pre-authenticated Unix socket connection.""" + conn = ActiveConnection( + self._logger, + self._hass, + self._send_message, + self._request[KEY_HASS_USER], + refresh_token=None, + ) + await self._send_bytes_text(AUTH_OK_MESSAGE) + self._logger.debug("Auth OK (unix socket)") + return conn + async def async_handle(self, msg: JsonValueType) -> ActiveConnection: """Handle authentication.""" try: diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 12473c8625580e..dad8ebe5686e24 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -59,14 +59,14 @@ def __init__( hass: HomeAssistant, send_message: Callable[[bytes | str | dict[str, Any]], None], user: User, - refresh_token: RefreshToken, + refresh_token: RefreshToken | None, ) -> None: """Initialize an active connection.""" self.logger = logger self.hass = hass self.send_message = send_message self.user = user - self.refresh_token_id = refresh_token.id + self.refresh_token_id = refresh_token.id if refresh_token else None self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.last_id = 0 self.can_coalesce = False diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index 0e9e0eb69330c9..27280f46516a9f 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -14,6 +14,7 @@ from aiohttp.http_websocket import WebSocketWriter from homeassistant.components.http import KEY_HASS, HomeAssistantView +from homeassistant.components.http.const import is_supervisor_unix_socket_request from homeassistant.const import EVENT_HOMEASSISTANT_STOP, EVENT_LOGGING_CHANGED from homeassistant.core import Event, HomeAssistant, callback from homeassistant.helpers.dispatcher import async_dispatcher_send @@ -36,12 +37,12 @@ from .messages import message_to_json_bytes from .util import describe_request -CLOSE_MSG_TYPES = {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING} -AUTH_MESSAGE_TIMEOUT = 10 # seconds - if TYPE_CHECKING: from .connection import ActiveConnection +CLOSE_MSG_TYPES = {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING} +AUTH_MESSAGE_TIMEOUT = 10 # seconds + _WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection") @@ -386,37 +387,45 @@ async def _async_handle_auth_phase( send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]], ) -> ActiveConnection: """Handle the auth phase of the websocket connection.""" - await send_bytes_text(AUTH_REQUIRED_MESSAGE) + request = self._request - # Auth Phase - try: - msg = await self._wsock.receive(AUTH_MESSAGE_TIMEOUT) - except TimeoutError as err: - raise Disconnect( - f"Did not receive auth message within {AUTH_MESSAGE_TIMEOUT} seconds" - ) from err - - if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): - raise Disconnect("Received close message during auth phase") - - if msg.type is not WSMsgType.TEXT: - if msg.type is WSMsgType.ERROR: - # msg.data is the exception + if is_supervisor_unix_socket_request(request): + # Unix socket requests are pre-authenticated by the HTTP + # auth middleware — skip the token exchange. + connection = await auth.async_handle_supervisor_unix_socket() + else: + await send_bytes_text(AUTH_REQUIRED_MESSAGE) + + # Auth Phase + try: + msg = await self._wsock.receive(AUTH_MESSAGE_TIMEOUT) + except TimeoutError as err: raise Disconnect( - f"Received error message during auth phase: {msg.data}" + f"Did not receive auth message within {AUTH_MESSAGE_TIMEOUT} seconds" + ) from err + + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): + raise Disconnect("Received close message during auth phase") + + if msg.type is not WSMsgType.TEXT: + if msg.type is WSMsgType.ERROR: + # msg.data is the exception + raise Disconnect( + f"Received error message during auth phase: {msg.data}" + ) + raise Disconnect( + f"Received non-Text message of type {msg.type} during auth phase" ) - raise Disconnect( - f"Received non-Text message of type {msg.type} during auth phase" - ) - try: - auth_msg_data = json_loads(msg.data) - except ValueError as err: - raise Disconnect("Received invalid JSON during auth phase") from err + try: + auth_msg_data = json_loads(msg.data) + except ValueError as err: + raise Disconnect("Received invalid JSON during auth phase") from err + + if self._debug: + self._logger.debug("%s: Received %s", self.description, auth_msg_data) + connection = await auth.async_handle(auth_msg_data) - if self._debug: - self._logger.debug("%s: Received %s", self.description, auth_msg_data) - connection = await auth.async_handle(auth_msg_data) # As the webserver is now started before the start # event we do not want to block for websocket responses # diff --git a/tests/components/http/test_auth.py b/tests/components/http/test_auth.py index ca66b8fef4be28..095ae8ad17a800 100644 --- a/tests/components/http/test_auth.py +++ b/tests/components/http/test_auth.py @@ -13,7 +13,7 @@ import pytest import yarl -from homeassistant.auth.const import GROUP_ID_READ_ONLY +from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY from homeassistant.auth.models import User from homeassistant.auth.providers import trusted_networks from homeassistant.auth.providers.homeassistant import HassAuthProvider @@ -32,6 +32,7 @@ current_request, setup_request_context, ) +from homeassistant.const import HASSIO_USER_NAME from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS from homeassistant.setup import async_setup_component @@ -658,3 +659,81 @@ async def test_create_user_once(hass: HomeAssistant) -> None: # test it did not create a user assert len(await hass.auth.async_get_users()) == cur_users + 1 + + +async def test_unix_socket_auth_with_supervisor_user( + hass: HomeAssistant, + app: web.Application, + aiohttp_client: ClientSessionGenerator, +) -> None: + """Test that Unix socket requests are authenticated as Supervisor user.""" + supervisor_user = await hass.auth.async_create_system_user( + HASSIO_USER_NAME, group_ids=[GROUP_ID_ADMIN] + ) + await hass.auth.async_create_refresh_token(supervisor_user) + + await async_setup_auth(hass, app) + client = await aiohttp_client(app) + + with patch( + "homeassistant.components.http.auth.is_supervisor_unix_socket_request", + return_value=True, + ): + req = await client.get("/") + assert req.status == HTTPStatus.OK + data = await req.json() + assert data["user_id"] == supervisor_user.id + + +async def test_unix_socket_auth_without_supervisor_user( + hass: HomeAssistant, + app: web.Application, + aiohttp_client: ClientSessionGenerator, +) -> None: + """Test that Unix socket requests return 500 when no Supervisor user exists.""" + await async_setup_auth(hass, app) + client = await aiohttp_client(app) + + with patch( + "homeassistant.components.http.auth.is_supervisor_unix_socket_request", + return_value=True, + ): + req = await client.get("/") + assert req.status == HTTPStatus.INTERNAL_SERVER_ERROR + + +async def test_unix_socket_auth_caches_user_id( + hass: HomeAssistant, + app: web.Application, + aiohttp_client: ClientSessionGenerator, +) -> None: + """Test that Unix socket auth caches the Supervisor user ID.""" + supervisor_user = await hass.auth.async_create_system_user( + HASSIO_USER_NAME, group_ids=[GROUP_ID_ADMIN] + ) + await hass.auth.async_create_refresh_token(supervisor_user) + + await async_setup_auth(hass, app) + client = await aiohttp_client(app) + + with patch( + "homeassistant.components.http.auth.is_supervisor_unix_socket_request", + return_value=True, + ): + # First request triggers user lookup + req = await client.get("/") + assert req.status == HTTPStatus.OK + + # Second request should use cached user ID + with ( + patch( + "homeassistant.components.http.auth.is_supervisor_unix_socket_request", + return_value=True, + ), + patch.object( + hass.auth, "async_get_users", wraps=hass.auth.async_get_users + ) as mock_get_users, + ): + req = await client.get("/") + assert req.status == HTTPStatus.OK + mock_get_users.assert_not_called() diff --git a/tests/components/http/test_ban.py b/tests/components/http/test_ban.py index 945bc69dee6cf9..b27c3838caf6ae 100644 --- a/tests/components/http/test_ban.py +++ b/tests/components/http/test_ban.py @@ -466,3 +466,34 @@ async def unauth_handler(request): await manager.async_add_ban(remote_ip) assert m_open.call_count == 1 + + +async def test_unix_socket_skips_ban_check( + hass: HomeAssistant, aiohttp_client: ClientSessionGenerator +) -> None: + """Test that Unix socket requests bypass ban middleware.""" + app = web.Application() + app[KEY_HASS] = hass + setup_bans(hass, app, 5) + set_real_ip = mock_real_ip(app) + + with patch( + "homeassistant.components.http.ban.load_yaml_config_file", + return_value={ + banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS + }, + ): + client = await aiohttp_client(app) + + # Verify the IP is actually banned for normal requests + set_real_ip(BANNED_IPS[0]) + resp = await client.get("/") + assert resp.status == HTTPStatus.FORBIDDEN + + # Unix socket requests should bypass ban checks + with patch( + "homeassistant.components.http.ban.is_supervisor_unix_socket_request", + return_value=True, + ): + resp = await client.get("/") + assert resp.status == HTTPStatus.NOT_FOUND diff --git a/tests/components/http/test_init.py b/tests/components/http/test_init.py index 87701aba657143..67774d0eadd475 100644 --- a/tests/components/http/test_init.py +++ b/tests/components/http/test_init.py @@ -6,6 +6,7 @@ from http import HTTPStatus from ipaddress import ip_network import logging +import os from pathlib import Path from unittest.mock import ANY, Mock, patch @@ -14,6 +15,7 @@ from homeassistant.auth.providers.homeassistant import HassAuthProvider from homeassistant.components import cloud, http from homeassistant.components.cloud import CloudNotAvailable +from homeassistant.const import HASSIO_USER_NAME from homeassistant.core import HomeAssistant from homeassistant.helpers import issue_registry as ir from homeassistant.helpers.http import KEY_HASS @@ -735,3 +737,74 @@ async def test_server_host( ) assert set(issue_registry.issues) == expected_issues + + +async def test_unix_socket_started_with_supervisor( + hass: HomeAssistant, + tmp_path: Path, +) -> None: + """Test unix socket is started when running under Supervisor.""" + await hass.auth.async_create_system_user( + HASSIO_USER_NAME, group_ids=["system-admin"] + ) + socket_path = tmp_path / "core.sock" + loop = asyncio.get_running_loop() + mock_sock = Mock() + with ( + patch.dict( + os.environ, {"SUPERVISOR_CORE_API_SOCKET": str(socket_path)}, clear=False + ), + patch("asyncio.BaseEventLoop.create_server", return_value=Mock()), + patch( + "homeassistant.components.http.web_runner.HomeAssistantUnixSite" + "._create_unix_socket", + return_value=mock_sock, + ) as mock_create_sock, + patch.object( + loop, "create_unix_server", return_value=Mock() + ) as mock_create_unix, + ): + assert await async_setup_component(hass, "http", {"http": {}}) + await hass.async_start() + await hass.async_block_till_done() + + mock_create_sock.assert_called_once() + mock_create_unix.assert_called_once_with(ANY, sock=mock_sock, backlog=128) + assert hass.http.supervisor_site is not None + + +async def test_unix_socket_not_started_without_supervisor( + hass: HomeAssistant, +) -> None: + """Test unix socket is not started when not running under Supervisor.""" + with ( + patch.dict(os.environ, {}, clear=False), + patch("asyncio.BaseEventLoop.create_server", return_value=Mock()), + ): + os.environ.pop("SUPERVISOR_CORE_API_SOCKET", None) + assert await async_setup_component(hass, "http", {"http": {}}) + await hass.async_start() + await hass.async_block_till_done() + + assert hass.http.supervisor_site is None + + +async def test_unix_socket_rejected_relative_path( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test unix socket is rejected when path is relative.""" + with ( + patch.dict( + os.environ, + {"SUPERVISOR_CORE_API_SOCKET": "relative/path.sock"}, + clear=False, + ), + patch("asyncio.BaseEventLoop.create_server", return_value=Mock()), + ): + assert await async_setup_component(hass, "http", {"http": {}}) + await hass.async_start() + await hass.async_block_till_done() + + assert hass.http.supervisor_site is None + assert "path must be absolute" in caplog.text diff --git a/tests/components/websocket_api/test_auth.py b/tests/components/websocket_api/test_auth.py index 49ee593fed7eef..09ff36c4ce02b5 100644 --- a/tests/components/websocket_api/test_auth.py +++ b/tests/components/websocket_api/test_auth.py @@ -18,6 +18,7 @@ SIGNAL_WEBSOCKET_DISCONNECTED, URL, ) +from homeassistant.const import HASSIO_USER_NAME from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.setup import async_setup_component @@ -367,3 +368,43 @@ async def test_error_right_after_auth_disconnects( assert close_error_msg.type is WSMsgType.CLOSE assert "Received error message during command phase: explode" in caplog.text + + +async def test_unix_socket_auth_bypass( + hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator +) -> None: + """Test that Unix socket connections skip websocket auth phase.""" + # Create the Supervisor system user + await hass.auth.async_create_system_user( + HASSIO_USER_NAME, group_ids=["system-admin"] + ) + + assert await async_setup_component(hass, "websocket_api", {}) + await hass.async_block_till_done() + + client = await hass_client_no_auth() + + with ( + patch( + "homeassistant.components.http.ban.is_supervisor_unix_socket_request", + return_value=True, + ), + patch( + "homeassistant.components.http.auth.is_supervisor_unix_socket_request", + return_value=True, + ), + patch( + "homeassistant.components.websocket_api.http.is_supervisor_unix_socket_request", + return_value=True, + ), + ): + async with client.ws_connect(URL) as ws: + # Should immediately receive auth_ok without sending a token + auth_msg = await ws.receive_json() + assert auth_msg["type"] == TYPE_AUTH_OK + + # Verify the connection works by sending a ping + await ws.send_json({"id": 1, "type": "ping"}) + pong_msg = await ws.receive_json() + assert pong_msg["type"] == "pong" + assert pong_msg["id"] == 1