diff --git a/aiohttp_sse/__init__.py b/aiohttp_sse/__init__.py index a28b3c9..770d0d5 100644 --- a/aiohttp_sse/__init__.py +++ b/aiohttp_sse/__init__.py @@ -40,6 +40,7 @@ def __init__( reason: Optional[str] = None, headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, + timeout: Optional[float] = None, ): super().__init__(status=status, reason=reason) @@ -55,6 +56,7 @@ def __init__( self._ping_interval: float = self.DEFAULT_PING_INTERVAL self._ping_task: Optional[asyncio.Task[None]] = None self._sep = sep if sep is not None else self.DEFAULT_SEPARATOR + self._timeout = timeout def is_connected(self) -> bool: """Check connection is prepared and ping task is not done.""" @@ -131,10 +133,16 @@ async def send( buffer.write(self._sep) try: - await self.write(buffer.getvalue().encode("utf-8")) + await asyncio.wait_for( # TODO(PY311): Use asyncio.timeout + self.write(buffer.getvalue().encode("utf-8")), + timeout=self._timeout, + ) except ConnectionResetError: self.stop_streaming() raise + except asyncio.TimeoutError: + self.stop_streaming() + raise TimeoutError async def wait(self) -> None: """EventSourceResponse object is used for streaming data to the client, @@ -205,8 +213,16 @@ async def _ping(self) -> None: while True: await asyncio.sleep(self._ping_interval) try: - await self.write(message) - except (ConnectionResetError, RuntimeError): + await asyncio.wait_for( # TODO(PY311): Use asyncio.timeout + self.write(message), + timeout=self._timeout, + ) + except ( + ConnectionResetError, + RuntimeError, + TimeoutError, + asyncio.TimeoutError, + ): # RuntimeError - on writing after EOF break @@ -259,6 +275,7 @@ def sse_response( headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, response_cls: type[EventSourceResponse] = EventSourceResponse, + timeout: Optional[float] = None, ) -> Any: if not issubclass(response_cls, EventSourceResponse): raise TypeError( @@ -266,5 +283,11 @@ def sse_response( "aiohttp_sse.EventSourceResponse, got {}".format(response_cls) ) - sse = response_cls(status=status, reason=reason, headers=headers, sep=sep) + sse = response_cls( + status=status, + reason=reason, + headers=headers, + sep=sep, + timeout=timeout, + ) return _ContextManager(sse._prepare(request)) diff --git a/tests/test_sse.py b/tests/test_sse.py index f46909b..33c297b 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -1,5 +1,6 @@ import asyncio import sys +from typing import Optional import pytest from aiohttp import web @@ -557,3 +558,40 @@ async def handler(request: web.Request) -> EventSourceResponse: async with client.get("/") as response: assert 200 == response.status + + +@pytest.mark.parametrize("timeout", (None, 0.1)) +async def test_with_timeout( + aiohttp_client: AiohttpClient, + monkeypatch: pytest.MonkeyPatch, + timeout: Optional[float], +) -> None: + """Test that a timeout occurs when client is not reading responses.""" + timeout_raised = False + should_raise_timeout = timeout is not None + + async def handler(request: web.Request) -> EventSourceResponse: + sse = EventSourceResponse(timeout=timeout) + await sse.prepare(request) + + async with sse: + while True: + # .send() only yields if socket is full, so yield here to run client. + await asyncio.sleep(0) + try: + await sse.send("x" * 10000000) # Enough data to fill socket + except TimeoutError: + nonlocal timeout_raised + timeout_raised = True + break + + return sse # pragma: no cover + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + async with client.get("/") as resp: + assert resp.status == 200 + await asyncio.sleep(0.5) + assert timeout_raised is should_raise_timeout