diff --git a/CHANGES/11761.bugfix.rst b/CHANGES/11761.bugfix.rst new file mode 100644 index 00000000000..d4661c6d4a1 --- /dev/null +++ b/CHANGES/11761.bugfix.rst @@ -0,0 +1,4 @@ +Fixed ``AssertionError`` when the transport is ``None`` during WebSocket +preparation or file response sending (e.g. when a client disconnects +immediately after connecting). A ``ConnectionResetError`` is now raised +instead -- by :user:`agners`. diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index eeaa2010f98..f339bec9662 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -128,7 +128,8 @@ async def _sendfile( loop = request._loop transport = request.transport - assert transport is not None + if transport is None: + raise ConnectionResetError("Connection lost") try: await loop.sendfile(transport, fobj, offset, count) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index ca129bb0f30..2aeeb6dec1f 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -361,7 +361,8 @@ def _pre_start(self, request: BaseRequest) -> tuple[str | None, WebSocketWriter] self.force_close() self._compress = compress transport = request._protocol.transport - assert transport is not None + if transport is None: + raise ConnectionResetError("Connection lost") writer = WebSocketWriter( request._protocol, transport, diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 87be2db182b..1d695d332c4 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -1,5 +1,6 @@ import asyncio import bz2 +import contextlib import gzip import pathlib import socket @@ -15,6 +16,7 @@ from aiohttp.compression_utils import ZLibBackend from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.typedefs import PathLike +from aiohttp.web_fileresponse import NOSENDFILE try: import brotlicffi as brotli @@ -1156,3 +1158,64 @@ async def handler(request: web.Request) -> web.FileResponse: resp.release() await client.close() + + +@pytest.mark.skipif(NOSENDFILE, reason="OS sendfile not available") +async def test_sendfile_after_client_disconnect( + aiohttp_client: AiohttpClient, tmp_path: pathlib.Path +) -> None: + """Test ConnectionResetError when client disconnects before sendfile. + + Reproduces the race condition where: + - Client sends a GET request for a file + - Handler does async work (e.g. auth check) before returning a FileResponse + - Client disconnects while the handler is busy + - Server then calls sendfile() → ConnectionResetError (not AssertionError) + + _send_headers_immediately is set to False so that super().prepare() + only buffers the headers without writing to the transport. Otherwise + _write() raises ClientConnectionResetError first and _sendfile()'s own + transport check is never reached. + """ + filepath = tmp_path / "test.txt" + filepath.write_bytes(b"x" * 1024) + + handler_started = asyncio.Event() + prepare_done = asyncio.Event() + captured_protocol = None + + async def handler(request: web.Request) -> web.Response: + nonlocal captured_protocol + resp = web.FileResponse(filepath) + resp._send_headers_immediately = False + captured_protocol = request._protocol + handler_started.set() + # Simulate async work (e.g., auth check) during which client disconnects. + await asyncio.sleep(0) + with pytest.raises(ConnectionResetError, match="Connection lost"): + await resp.prepare(request) + prepare_done.set() + return web.Response(status=503) + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + request_task = asyncio.create_task(client.get("/")) + + # Wait until the handler is running but has not yet returned the response. + await handler_started.wait() + assert captured_protocol is not None + + # Simulate the client disconnecting by setting transport to None directly. + # We cannot use force_close() because closing the TCP transport triggers + # connection_lost() which cancels the handler task before it can call + # prepare() and hit the ConnectionResetError in _sendfile(). + captured_protocol.transport = None + + # Wait for the handler to resume, call prepare(), and hit ConnectionResetError. + await asyncio.wait_for(prepare_done.wait(), timeout=1) + + request_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await request_task diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 1e202649c6a..7257c47ba73 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -11,7 +11,7 @@ import pytest import aiohttp -from aiohttp import WSServerHandshakeError, web +from aiohttp import WSServerHandshakeError, hdrs, web from aiohttp.http import WSCloseCode, WSMsgType from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer @@ -1661,3 +1661,58 @@ async def websocket_handler( assert msg.type is aiohttp.WSMsgType.TEXT assert msg.data == "success" await ws.close() + + +async def test_prepare_after_client_disconnect(aiohttp_client: AiohttpClient) -> None: + """Test ConnectionResetError when client disconnects before ws.prepare(). + + Reproduces the race condition where: + - Client connects and sends a WebSocket upgrade request + - Handler starts async work (e.g. authentication) before calling ws.prepare() + - Client disconnects while the handler is busy + - Handler then calls ws.prepare() → ConnectionResetError (not AssertionError) + """ + handler_started = asyncio.Event() + captured_protocol = None + + async def handler(request: web.Request) -> web.Response: + nonlocal captured_protocol + ws = web.WebSocketResponse() + captured_protocol = request._protocol + handler_started.set() + # Simulate async work (e.g., auth check) during which client disconnects. + await asyncio.sleep(0) + with pytest.raises(ConnectionResetError, match="Connection lost"): + await ws.prepare(request) + return web.Response(status=503) + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + request_task = asyncio.create_task( + client.session.get( + client.make_url("/"), + headers={ + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "Upgrade", + hdrs.SEC_WEBSOCKET_KEY: "dGhlIHNhbXBsZSBub25jZQ==", + hdrs.SEC_WEBSOCKET_VERSION: "13", + }, + ) + ) + + # Wait until the handler is running but has not yet called ws.prepare(). + await handler_started.wait() + assert captured_protocol is not None + + # Simulate the client disconnecting abruptly. + captured_protocol.force_close() + + # Yield so the handler can resume and hit the ConnectionResetError. + await asyncio.sleep(0) + + with contextlib.suppress( + aiohttp.ServerDisconnectedError, aiohttp.ClientConnectionResetError + ): + await request_task