From 8d6be2b83bcf26681f2527152fd08da6988321dc Mon Sep 17 00:00:00 2001 From: Oleg A Date: Mon, 12 Feb 2024 23:02:32 +0300 Subject: [PATCH 1/3] feat: support timeout --- aiohttp_sse/__init__.py | 31 +++++++++++++++++++++++---- tests/test_sse.py | 46 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/aiohttp_sse/__init__.py b/aiohttp_sse/__init__.py index ea36987..0a90d90 100644 --- a/aiohttp_sse/__init__.py +++ b/aiohttp_sse/__init__.py @@ -39,6 +39,7 @@ def __init__( reason: Optional[str] = None, headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, + timeout: Optional[float] = None, ): super().__init__(status=status, reason=reason) @@ -54,6 +55,7 @@ def __init__( self._ping_interval: float = self.DEFAULT_PING_INTERVAL self._ping_task: Optional[asyncio.Task[None]] = None self._sep = sep if sep is not None else self.DEFAULT_SEPARATOR + self._timeout = timeout def is_connected(self) -> bool: """Check connection is prepared and ping task is not done.""" @@ -130,10 +132,16 @@ async def send( buffer.write(self._sep) try: - await self.write(buffer.getvalue().encode("utf-8")) + await asyncio.wait_for( # TODO(PY311): Use asyncio.timeout + self.write(buffer.getvalue().encode("utf-8")), + timeout=self._timeout, + ) except ConnectionResetError: self.stop_streaming() raise + except asyncio.TimeoutError: + self.stop_streaming() + raise TimeoutError async def wait(self) -> None: """EventSourceResponse object is used for streaming data to the client, @@ -202,8 +210,16 @@ async def _ping(self) -> None: while True: await asyncio.sleep(self._ping_interval) try: - await self.write(message) - except (ConnectionResetError, RuntimeError): + await asyncio.wait_for( # TODO(PY311): Use asyncio.timeout + self.write(message), + timeout=self._timeout, + ) + except ( + ConnectionResetError, + RuntimeError, + TimeoutError, + asyncio.TimeoutError, + ): # RuntimeError - on writing after EOF break @@ -256,6 +272,7 @@ def sse_response( headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, response_cls: Type[EventSourceResponse] = EventSourceResponse, + timeout: Optional[float] = None, ) -> Any: if not issubclass(response_cls, EventSourceResponse): raise TypeError( @@ -263,5 +280,11 @@ def sse_response( "aiohttp_sse.EventSourceResponse, got {}".format(response_cls) ) - sse = response_cls(status=status, reason=reason, headers=headers, sep=sep) + sse = response_cls( + status=status, + reason=reason, + headers=headers, + sep=sep, + timeout=timeout, + ) return _ContextManager(sse._prepare(request)) diff --git a/tests/test_sse.py b/tests/test_sse.py index 466df76..3151c63 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -1,6 +1,6 @@ import asyncio import sys -from typing import Awaitable, Callable, List +from typing import Awaitable, Callable, List, Optional import pytest from aiohttp import web @@ -559,3 +559,47 @@ async def handler(request: web.Request) -> EventSourceResponse: async with client.get("/") as response: assert 200 == response.status + + +@pytest.mark.parametrize("timeout", (None, 0.1)) +async def test_with_timeout( + aiohttp_client: ClientFixture, + monkeypatch: pytest.MonkeyPatch, + timeout: Optional[float], +) -> None: + """Test write timeout. + + Relates to this issue: + https://github.com/sysid/sse-starlette/issues/89 + """ + timeout_raised = False + + async def frozen_write(_data: bytes) -> None: + await asyncio.sleep(42) + + async def handler(request: web.Request) -> EventSourceResponse: + sse = EventSourceResponse(timeout=timeout) + sse.ping_interval = 42 + await sse.prepare(request) + monkeypatch.setattr(sse, "write", frozen_write) + + async with sse: + try: + await sse.send("foo") + except TimeoutError: + nonlocal timeout_raised + timeout_raised = True + raise + + return sse + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + async with client.get("/") as resp: + assert resp.status == 200 + await asyncio.sleep(0.5) + assert resp.connection.closed is bool(timeout) + + assert timeout_raised is bool(timeout) From b0e93c54bd32a84cd728605087ebc70a11d3d69f Mon Sep 17 00:00:00 2001 From: Oleg A Date: Tue, 13 Feb 2024 00:00:03 +0300 Subject: [PATCH 2/3] chore: mypy + coverage --- tests/test_sse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_sse.py b/tests/test_sse.py index 3151c63..dbbfdcf 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -591,7 +591,7 @@ async def handler(request: web.Request) -> EventSourceResponse: timeout_raised = True raise - return sse + return sse # pragma: no cover app = web.Application() app.router.add_route("GET", "/", handler) @@ -600,6 +600,6 @@ async def handler(request: web.Request) -> EventSourceResponse: async with client.get("/") as resp: assert resp.status == 200 await asyncio.sleep(0.5) - assert resp.connection.closed is bool(timeout) + assert resp.connection and resp.connection.closed is bool(timeout) assert timeout_raised is bool(timeout) From 26251da9f57fc2d744a44d613f172802bf38162e Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Thu, 25 Apr 2024 20:23:01 +0100 Subject: [PATCH 3/3] Tweak test --- tests/test_sse.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/tests/test_sse.py b/tests/test_sse.py index dbbfdcf..52e2cf9 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -567,29 +567,24 @@ async def test_with_timeout( monkeypatch: pytest.MonkeyPatch, timeout: Optional[float], ) -> None: - """Test write timeout. - - Relates to this issue: - https://github.com/sysid/sse-starlette/issues/89 - """ + """Test that a timeout occurs when client is not reading responses.""" timeout_raised = False - - async def frozen_write(_data: bytes) -> None: - await asyncio.sleep(42) + should_raise_timeout = timeout is not None async def handler(request: web.Request) -> EventSourceResponse: sse = EventSourceResponse(timeout=timeout) - sse.ping_interval = 42 await sse.prepare(request) - monkeypatch.setattr(sse, "write", frozen_write) async with sse: - try: - await sse.send("foo") - except TimeoutError: - nonlocal timeout_raised - timeout_raised = True - raise + while True: + # .send() only yields if socket is full, so yield here to run client. + await asyncio.sleep(0) + try: + await sse.send("x" * 10000000) # Enough data to fill socket + except TimeoutError: + nonlocal timeout_raised + timeout_raised = True + break return sse # pragma: no cover @@ -600,6 +595,4 @@ async def handler(request: web.Request) -> EventSourceResponse: async with client.get("/") as resp: assert resp.status == 200 await asyncio.sleep(0.5) - assert resp.connection and resp.connection.closed is bool(timeout) - - assert timeout_raised is bool(timeout) + assert timeout_raised is should_raise_timeout