diff --git a/src/chat_sdk/adapters/google_chat/adapter.py b/src/chat_sdk/adapters/google_chat/adapter.py index 10503b3..8d49d43 100644 --- a/src/chat_sdk/adapters/google_chat/adapter.py +++ b/src/chat_sdk/adapters/google_chat/adapter.py @@ -68,6 +68,7 @@ EphemeralMessage, FetchOptions, FetchResult, + FileUpload, FormattedContent, ListThreadsOptions, ListThreadsResult, @@ -430,6 +431,72 @@ async def _gchat_api_request( ) return result + # ========================================================================= + # Media upload + # ========================================================================= + + async def _gchat_media_upload( + self, + files: list[FileUpload], + space_name: str, + thread_name: str | None, + label: str, + ) -> None: + """Upload files to a GChat space via the multipart media upload endpoint. + + One request is sent per file -- the GChat upload API does not support batching. + Per-file errors are logged but never raised so a transport hiccup cannot kill + a turn whose text response already landed. + """ + if not files: + return + + token = await self._get_access_token() + session = await self._get_http_session() + upload_url = f"https://chat.googleapis.com/upload/v1/{space_name}/messages" + + for file in files: + try: + boundary = f"boundary_{_random_id()}" + metadata: dict[str, Any] = {"text": label} + if thread_name: + metadata["thread"] = {"name": thread_name} + + body = ( + f"--{boundary}\r\n" + f"Content-Type: application/json; charset=UTF-8\r\n\r\n" + f"{json.dumps(metadata)}\r\n" + f"--{boundary}\r\n" + f"Content-Type: {file.mime_type or 'application/octet-stream'}\r\n\r\n" + ).encode() + file.data + f"\r\n--{boundary}--".encode() + + async with session.request( + "POST", + upload_url, + data=body, + params={"uploadType": "multipart"}, + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": f"multipart/related; boundary={boundary}", + }, + ) as response: + if response.status >= 400: + error_text = await response.text() + self._logger.error( + f"GChat media upload failed for {file.filename}", + {"status": response.status, "error": error_text}, + ) + else: + self._logger.debug( + "GChat media upload succeeded", + {"filename": file.filename}, + ) + except Exception: + self._logger.error( + f"GChat media upload failed for {file.filename}", + {"exc_info": True}, + ) + # ========================================================================= # Lifecycle # ========================================================================= @@ -1346,13 +1413,22 @@ async def post_message( thread_name = decoded.thread_name try: - # Check for files - currently not implemented for GChat files = extract_files(message) if files: - self._logger.warn( - "File uploads are not yet supported for Google Chat. Files will be ignored.", - {"fileCount": len(files)}, + await self._gchat_media_upload(files, space_name, thread_name, "") + has_text = ( + isinstance(message, str) + or (hasattr(message, "raw") and getattr(message, "raw", None)) + or (hasattr(message, "markdown") and getattr(message, "markdown", None)) + or (hasattr(message, "ast") and getattr(message, "ast", None)) ) + card = extract_card(message) + if not (has_text or card): + return RawMessage( + id=f"file-{int(time.time() * 1000)}", + thread_id=thread_id, + raw={"files": files}, + ) # Check if message contains a card card = extract_card(message) diff --git a/tests/test_google_chat_adapter.py b/tests/test_google_chat_adapter.py index c726af0..1de87e9 100644 --- a/tests/test_google_chat_adapter.py +++ b/tests/test_google_chat_adapter.py @@ -284,3 +284,138 @@ def test_is_trusted_gchat_download_url_allowlist(self): assert not GoogleChatAdapter._is_trusted_gchat_download_url("https://attacker.example/x") # Rejects look-alikes assert not GoogleChatAdapter._is_trusted_gchat_download_url("https://chat.googleapis.com.attacker.tld/x") + + +# --------------------------------------------------------------------------- +# _gchat_media_upload +# --------------------------------------------------------------------------- + + +class TestGChatMediaUpload: + """GChat file delivery via the media upload endpoint, exercised through post_message.""" + + @staticmethod + def _make_fake_session(): + """Return (session, calls_list). calls_list accumulates every request kwargs dict.""" + calls: list[dict] = [] + + class _FakeResponse: + status = 200 + + async def json(self) -> dict: + return {"name": "spaces/TEST/messages/1"} + + async def text(self) -> str: + return "" + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + class _FakeSession: + def request(self, method: str, url: str, **kwargs) -> _FakeResponse: + calls.append({"method": method, "url": url, **kwargs}) + return _FakeResponse() + + return _FakeSession(), calls + + @pytest.mark.asyncio + async def test_post_message_sends_one_multipart_upload_per_file(self): + """Two files produce two POSTs to the GChat media upload endpoint with + uploadType=multipart -- the upload API does not support batching.""" + from unittest.mock import AsyncMock + + from chat_sdk.adapters.google_chat.thread_utils import GoogleChatThreadId, encode_thread_id + from chat_sdk.types import FileUpload, PostableMarkdown + + adapter = _make_adapter() + adapter._get_access_token = AsyncMock(return_value="tok") # type: ignore[method-assign] + session, calls = self._make_fake_session() + adapter._get_http_session = AsyncMock(return_value=session) # type: ignore[method-assign] + + thread_id = encode_thread_id(GoogleChatThreadId(space_name="spaces/TEST")) + message = PostableMarkdown( + markdown="", + files=[ + FileUpload(data=b"csv", filename="data.csv", mime_type="text/csv"), + FileUpload(data=b"\x89PNG", filename="chart.png", mime_type="image/png"), + ], + ) + + await adapter.post_message(thread_id, message) + + upload_calls = [call for call in calls if "/upload/v1/" in call["url"]] + assert len(upload_calls) == 2, ( + f"expected one media upload call per file; got {len(upload_calls)}. " + "GChat's upload endpoint does not batch -- see _gchat_media_upload in google_chat/adapter.py." + ) + for call in upload_calls: + assert call.get("params", {}).get("uploadType") == "multipart", ( + "uploadType=multipart query param required for GChat media upload; " + f"got params={call.get('params')}" + ) + + @pytest.mark.asyncio + async def test_files_only_post_message_does_not_emit_text_message(self): + """post_message with files and no text/card returns early after uploading -- + a separate empty text message must not be posted to the standard endpoint.""" + from unittest.mock import AsyncMock + + from chat_sdk.adapters.google_chat.thread_utils import GoogleChatThreadId, encode_thread_id + from chat_sdk.types import FileUpload, PostableMarkdown + + adapter = _make_adapter() + adapter._get_access_token = AsyncMock(return_value="tok") # type: ignore[method-assign] + session, calls = self._make_fake_session() + adapter._get_http_session = AsyncMock(return_value=session) # type: ignore[method-assign] + + thread_id = encode_thread_id(GoogleChatThreadId(space_name="spaces/TEST")) + message = PostableMarkdown( + markdown="", + files=[FileUpload(data=b"hello", filename="report.txt", mime_type="text/plain")], + ) + + result = await adapter.post_message(thread_id, message) + + standard_api_calls = [call for call in calls if "/upload/v1/" not in call["url"]] + assert standard_api_calls == [], ( + "post_message must return early after file upload when there is no text or card -- " + "an empty text message alongside files is unwanted. " + f"unexpected standard-API calls: {standard_api_calls}" + ) + assert result is not None + + @pytest.mark.asyncio + async def test_post_message_file_upload_failure_does_not_raise(self): + """A file upload network error must not propagate out of post_message -- + a transport hiccup cannot kill a turn whose text response already landed.""" + from unittest.mock import AsyncMock + + from chat_sdk.adapters.google_chat.thread_utils import GoogleChatThreadId, encode_thread_id + from chat_sdk.types import FileUpload, PostableMarkdown + + adapter = _make_adapter() + adapter._get_access_token = AsyncMock(return_value="tok") # type: ignore[method-assign] + + class _FailResponse: + async def __aenter__(self): + raise RuntimeError("network error") + + async def __aexit__(self, *args): + pass + + class _FailSession: + def request(self, method: str, url: str, **kwargs) -> _FailResponse: + return _FailResponse() + + adapter._get_http_session = AsyncMock(return_value=_FailSession()) # type: ignore[method-assign] + + thread_id = encode_thread_id(GoogleChatThreadId(space_name="spaces/TEST")) + message = PostableMarkdown( + markdown="", + files=[FileUpload(data=b"x", filename="file.txt", mime_type="text/plain")], + ) + + await adapter.post_message(thread_id, message)