diff --git a/src/chat_sdk/adapters/teams/adapter.py b/src/chat_sdk/adapters/teams/adapter.py index 33bc85f..2393a0b 100644 --- a/src/chat_sdk/adapters/teams/adapter.py +++ b/src/chat_sdk/adapters/teams/adapter.py @@ -16,7 +16,7 @@ import re from collections.abc import Awaitable, Callable from datetime import datetime, timezone -from typing import Any, Literal, NoReturn +from typing import Any, Literal, NoReturn, cast from urllib.parse import urlparse from chat_sdk.adapters.teams.cards import card_to_adaptive_card @@ -24,6 +24,8 @@ from chat_sdk.adapters.teams.types import ( TeamsAdapterConfig, TeamsChannelContext, + TeamsDmContext, + TeamsGraphContext, TeamsThreadId, ) from chat_sdk.emoji import convert_emoji_placeholders @@ -61,6 +63,11 @@ MESSAGEID_CAPTURE_PATTERN = re.compile(r"messageid=(\d+)") MESSAGEID_STRIP_PATTERN = re.compile(r";messageid=\d+") +# AAD object IDs are GUIDs (8-4-4-4-12 hex). Used to gate ``aadObjectId`` +# values from incoming activities before formatting them into Microsoft +# Graph chat IDs (vercel/chat#403). See ``_cache_user_context`` and +# ``_chat_id_from_context``. +_AAD_OBJECT_ID_PATTERN = re.compile(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}") CACHE_TTL_MS = 30 * 24 * 60 * 60 * 1000 # 30 days # Allowed Microsoft Bot Framework service URL patterns (SSRF protection). @@ -308,12 +315,44 @@ async def _cache_user_context(self, activity: dict[str, Any]) -> None: base_channel_id = MESSAGEID_STRIP_PATTERN.sub("", conversation_id) if team_aad_group_id and channel_data.get("channel", {}).get("id") and state: + # Wire-shape parity with upstream TS (#403): the channel branch + # omits the discriminator. ``_chat_id_from_context`` and + # ``_get_graph_context`` treat ``type != "dm"`` as channel, so + # the missing key is unambiguous. 
context: TeamsChannelContext = { "team_id": team_aad_group_id, "channel_id": channel_data["channel"]["id"], } await state.set(f"teams:channelContext:{base_channel_id}", json.dumps(context), ttl) + # Cache DM context for Microsoft Graph chat ID resolution + # (vercel/chat#403). Bot Framework hands out opaque DM conversation + # IDs that Graph's ``/chats/{chat-id}/messages`` endpoint rejects; + # the canonical Graph chat ID for a 1:1 DM is + # ``19:{userAadId}_{botId}@unq.gbl.spaces``. ``aadObjectId`` is + # only present for real Teams users (not bots), and DM conversation + # IDs do not start with ``19:`` (channel/group chats do). + # + # Defense-in-depth: AAD object IDs are GUIDs (8-4-4-4-12 hex). Bot + # Framework JWT verification authenticates the activity envelope + # but does not constrain ``from.aadObjectId``; a malformed value + # could otherwise inject ``/`` / ``?`` / ``#`` into the Graph URL + # path and cause a misrouted request. Reject anything that doesn't + # match the GUID shape before formatting it into the chat ID. 
+ aad_object_id = from_user.get("aadObjectId") + if ( + isinstance(aad_object_id, str)  # non-str JSON values would make fullmatch() raise TypeError + and self._app_id + and not base_channel_id.startswith("19:") + and state + and _AAD_OBJECT_ID_PATTERN.fullmatch(aad_object_id) + ): + dm_context: TeamsDmContext = { + "graph_chat_id": f"19:{aad_object_id}_{self._app_id}@unq.gbl.spaces", + "type": "dm", + } + await state.set(f"teams:channelContext:{base_channel_id}", json.dumps(dm_context), ttl) + async def _handle_message_activity( self, activity: dict[str, Any], @@ -1036,7 +1075,13 @@ async def fetch_messages( thread_message_id = message_id_match.group(1) if message_id_match else None base_conversation_id = MESSAGEID_STRIP_PATTERN.sub("", conversation_id) - channel_context = await self._get_channel_context(base_conversation_id) if thread_message_id else None + # vercel/chat#403: always look up the Graph context, not just for + # channel threads — DMs need it to map the opaque Bot Framework + # conversation ID to the canonical Graph chat ID + # (``19:{aadId}_{botId}@unq.gbl.spaces``) that ``/chats/{id}`` + # accepts. + graph_context = await self._get_graph_context(base_conversation_id) + context_type: str | None = graph_context.get("type") if graph_context else None try: self._logger.debug( @@ -1044,14 +1089,16 @@ async def fetch_messages( { "conversationId": base_conversation_id, "threadMessageId": thread_message_id, - "hasChannelContext": channel_context is not None, + "contextType": context_type or "none", "limit": limit, "cursor": cursor, "direction": direction, }, ) - if channel_context and thread_message_id: + if graph_context and context_type != "dm" and thread_message_id: + # Narrowed: channel context for a channel thread. 
+ channel_context = cast(TeamsChannelContext, graph_context) return await self._fetch_channel_thread_messages( channel_context, thread_message_id, @@ -1059,30 +1106,10 @@ async def fetch_messages( options, ) - graph_messages: list[dict[str, Any]] - has_more = False + chat_id = self._chat_id_from_context(graph_context, base_conversation_id) + graph_messages, has_more = await self._paginate_graph_chat_messages(chat_id, limit, direction, cursor) - if direction == "forward": - params: dict[str, Any] = { - "$top": limit, - "$orderby": "createdDateTime asc", - } - if cursor: - params["$filter"] = f"createdDateTime gt {cursor}" - graph_messages = await self._graph_list_chat_messages(base_conversation_id, params) - has_more = len(graph_messages) >= limit - else: - params = { - "$top": limit, - "$orderby": "createdDateTime desc", - } - if cursor: - params["$filter"] = f"createdDateTime lt {cursor}" - graph_messages = await self._graph_list_chat_messages(base_conversation_id, params) - graph_messages.reverse() - has_more = len(graph_messages) >= limit - - if thread_message_id and not channel_context: + if thread_message_id and not graph_context: graph_messages = [msg for msg in graph_messages if msg.get("id") and msg["id"] >= thread_message_id] self._logger.debug( "Filtered group chat messages to thread", @@ -1134,13 +1161,14 @@ async def fetch_channel_messages( direction = options.direction or "backward" try: - channel_context = await self._get_channel_context(base_conversation_id) + graph_context = await self._get_graph_context(base_conversation_id) + context_type = graph_context.get("type") if graph_context else None self._logger.debug( "Teams Graph API: fetchChannelMessages", { "conversationId": base_conversation_id, - "hasChannelContext": channel_context is not None, + "contextType": context_type or "none", "limit": limit, "direction": direction, }, @@ -1149,7 +1177,8 @@ async def fetch_channel_messages( graph_messages: list[dict[str, Any]] has_more = False - if 
channel_context: + if graph_context and context_type != "dm": + channel_context = cast(TeamsChannelContext, graph_context) if direction == "forward": graph_messages = await self._graph_list_channel_messages( channel_context["team_id"], @@ -1176,19 +1205,14 @@ async def fetch_channel_messages( ) graph_messages.reverse() has_more = len(graph_messages) >= limit - elif direction == "forward": - params = {"$top": limit, "$orderby": "createdDateTime asc"} - if options.cursor: - params["$filter"] = f"createdDateTime gt {options.cursor}" - graph_messages = await self._graph_list_chat_messages(base_conversation_id, params) - has_more = len(graph_messages) >= limit else: - params = {"$top": limit, "$orderby": "createdDateTime desc"} - if options.cursor: - params["$filter"] = f"createdDateTime lt {options.cursor}" - graph_messages = await self._graph_list_chat_messages(base_conversation_id, params) - graph_messages.reverse() - has_more = len(graph_messages) >= limit + # vercel/chat#403: DM contexts substitute the canonical Graph + # chat ID for the opaque Bot Framework conversation ID; no + # context (group chat) falls through to the raw ID. + chat_id = self._chat_id_from_context(graph_context, base_conversation_id) + graph_messages, has_more = await self._paginate_graph_chat_messages( + chat_id, limit, direction, options.cursor + ) messages = [self._map_graph_message(msg, channel_id) for msg in graph_messages if msg.get("id")] @@ -1225,7 +1249,14 @@ async def fetch_channel_info(self, channel_id: str) -> ChannelInfo: base_conversation_id = MESSAGEID_STRIP_PATTERN.sub("", conversation_id) is_dm = not conversation_id.startswith("19:") - channel_context = await self._get_channel_context(base_conversation_id) if not is_dm else None + # vercel/chat#403: only call into the Graph teams/channels + # endpoint for true channel contexts. A cached DM context (now + # possible when ``aadObjectId`` was present on the activity) + # must not be treated as a channel. 
+ graph_context = await self._get_graph_context(base_conversation_id) if not is_dm else None + channel_context: TeamsChannelContext | None = None + if graph_context and graph_context.get("type") != "dm": + channel_context = cast(TeamsChannelContext, graph_context) if channel_context: try: @@ -1344,8 +1375,19 @@ async def disconnect(self) -> None: # Graph API — internal helpers # ========================================================================= - async def _get_channel_context(self, base_conversation_id: str) -> TeamsChannelContext | None: - """Look up cached channel context (team_id, channel_id) for a conversation.""" + async def _get_graph_context(self, base_conversation_id: str) -> TeamsGraphContext | None: + """Look up cached Microsoft Graph context for a conversation. + + Returns either a :class:`TeamsChannelContext` (channel/team + thread) or a :class:`TeamsDmContext` (1:1 DM with a resolved + Graph chat ID). For group chats, no entry is cached — the raw + conversation ID works as-is with Graph's ``/chats`` endpoints. + + Backwards compat: cached entries written before vercel/chat#403 + omit the ``type`` discriminator and are treated as + ``"channel"`` by :meth:`_chat_id_from_context` and the call + sites that branch on context type. + """ if not self._chat: return None state = self._chat.get_state() @@ -1359,6 +1401,21 @@ async def _get_channel_context(self, base_conversation_id: str) -> TeamsChannelC pass return None + @staticmethod + def _chat_id_from_context( + context: TeamsGraphContext | None, + base_conversation_id: str, + ) -> str: + """Resolve the Microsoft Graph chat ID for a non-channel conversation. + + Uses the DM context's ``graph_chat_id`` when present, otherwise + falls back to the raw Bot Framework conversation ID (which works + as-is for group chats and the legacy pre-#403 cache shape). 
+ """ + if context is not None and context.get("type") == "dm": + return context["graph_chat_id"] + return base_conversation_id + async def _graph_list_chat_messages( self, chat_id: str, @@ -1380,6 +1437,28 @@ async def _graph_list_chat_messages( data = await response.json() return data.get("value", []) + async def _paginate_graph_chat_messages( + self, + chat_id: str, + limit: int, + direction: str, + cursor: str | None, + ) -> tuple[list[dict[str, Any]], bool]: + """Issue a single Graph ``/chats/{chat_id}/messages`` page and report has_more. + + Backward direction reverses the result so callers always see chronological + order; cursor filter clause is ``gt`` for forward, ``lt`` for backward. + """ + order_by = "createdDateTime asc" if direction == "forward" else "createdDateTime desc" + filter_op = "gt" if direction == "forward" else "lt" + params: dict[str, Any] = {"$top": limit, "$orderby": order_by} + if cursor: + params["$filter"] = f"createdDateTime {filter_op} {cursor}" + graph_messages = await self._graph_list_chat_messages(chat_id, params) + if direction != "forward": + graph_messages.reverse() + return graph_messages, len(graph_messages) >= limit + async def _graph_list_channel_messages( self, team_id: str, diff --git a/src/chat_sdk/adapters/teams/types.py b/src/chat_sdk/adapters/teams/types.py index cda4e2e..7e6f416 100644 --- a/src/chat_sdk/adapters/teams/types.py +++ b/src/chat_sdk/adapters/teams/types.py @@ -99,11 +99,37 @@ class TeamsThreadId: # ============================================================================= -class TeamsChannelContext(TypedDict): - """Teams channel context extracted from activity.channelData.""" +class TeamsChannelContext(TypedDict, total=False): + """Teams channel context extracted from activity.channelData. + + The ``type`` discriminator is optional for backwards-compatibility: + cached entries written before vercel/chat#403 omit it, and downstream + code treats a missing ``type`` as ``"channel"``. 
+ """ channel_id: str team_id: str + type: str # Literal["channel"] when present + + +class TeamsDmContext(TypedDict): + """Teams DM context with the resolved Microsoft Graph chat ID. + + Bot Framework hands out opaque DM conversation IDs (e.g. + ``a:1xWhatever``) which are *not* accepted by Graph's + ``/chats/{chat-id}/messages`` endpoint. The canonical Graph chat ID + for a 1:1 DM is ``19:{userAadId}_{botId}@unq.gbl.spaces`` — derive + and cache it from the incoming activity's ``from.aadObjectId``. + """ + + graph_chat_id: str + type: str # Literal["dm"] + + +# Discriminated union for Microsoft Graph API resolution context. +# Group chats are not represented — their conversation ID works as-is +# with Graph's chat endpoints. +TeamsGraphContext = TeamsChannelContext | TeamsDmContext # ============================================================================= diff --git a/src/chat_sdk/thread.py b/src/chat_sdk/thread.py index 24eba7b..6154377 100644 --- a/src/chat_sdk/thread.py +++ b/src/chat_sdk/thread.py @@ -141,6 +141,48 @@ def _is_async_iterable(value: Any) -> bool: return value is not None and hasattr(value, "__aiter__") +def _extract_slack_recipient_team_id(raw: Any) -> str | None: + """Resolve the Slack workspace ID from a ``currentMessage.raw`` payload. + + Slack carries the workspace ID in different shapes depending on the + webhook envelope: + + - Message events (``message`` / ``app_mention``): top-level ``team_id`` + (string) or ``team`` (string). + - Interactive payloads (``block_actions``, ``view_submission``, etc.): + nested ``team.id`` on a ``team`` object, with ``user.team_id`` as a + final fallback when ``team`` is missing entirely. + + Upstream parity: vercel/chat#330. The previous extraction + (``raw.get("team_id") or raw.get("team")``) returned the entire ``team`` + dict for block_actions, which then traveled to the Slack adapter as the + ``recipient_team_id`` and either crashed downstream API calls or sent + them to the wrong workspace. 
+ """ + if not isinstance(raw, dict): + return None + + team_id = raw.get("team_id") + if isinstance(team_id, str) and team_id: + return team_id + + team = raw.get("team") + if isinstance(team, str) and team: + return team + if isinstance(team, dict): + nested = team.get("id") + if isinstance(nested, str) and nested: + return nested + + user = raw.get("user") + if isinstance(user, dict): + user_team_id = user.get("team_id") + if isinstance(user_team_id, str) and user_team_id: + return user_team_id + + return None + + def _extract_message_content( message: AdapterPostableMessage, ) -> tuple[str, FormattedContent, list[Attachment]]: @@ -615,9 +657,13 @@ async def _handle_stream( options = StreamOptions() if self._current_message is not None: options.recipient_user_id = self._current_message.author.user_id - raw = self._current_message.raw - if isinstance(raw, dict): - options.recipient_team_id = raw.get("team_id") or raw.get("team") + # recipient_team_id is only consumed by the Slack adapter; other + # adapters ignore it. Slack carries the workspace ID in different + # shapes depending on webhook type — message events use the + # top-level ``team_id`` / ``team`` (string), block_actions / view + # payloads use ``team.id`` (object), with ``user.team_id`` as a + # final fallback. Upstream parity: vercel/chat#330. + options.recipient_team_id = _extract_slack_recipient_team_id(self._current_message.raw) # Merge caller-supplied StreamingPlan options on top. Explicit fields win. 
if extra_options is not None: diff --git a/tests/test_coverage_gaps.py b/tests/test_coverage_gaps.py index 5fbf098..6872434 100644 --- a/tests/test_coverage_gaps.py +++ b/tests/test_coverage_gaps.py @@ -339,7 +339,7 @@ async def test_fetch_messages_dm_chat(self): # Mock the Graph API methods adapter._get_graph_token = AsyncMock(return_value="fake-token") - adapter._get_channel_context = AsyncMock(return_value=None) + adapter._get_graph_context = AsyncMock(return_value=None) fake_graph_messages = [ { diff --git a/tests/test_teams_coverage.py b/tests/test_teams_coverage.py index b857325..5600f6a 100644 --- a/tests/test_teams_coverage.py +++ b/tests/test_teams_coverage.py @@ -1537,14 +1537,14 @@ def test_default_file_type(self): # --------------------------------------------------------------------------- -# _get_channel_context edge cases +# _get_graph_context edge cases # --------------------------------------------------------------------------- -class TestGetChannelContext: +class TestGetGraphContext: async def test_no_chat(self): adapter = _make_adapter(logger=_make_logger()) - result = await adapter._get_channel_context("19:abc") + result = await adapter._get_graph_context("19:abc") assert result is None async def test_no_state(self): @@ -1552,7 +1552,7 @@ async def test_no_state(self): chat = MagicMock() chat.get_state = MagicMock(return_value=None) await adapter.initialize(chat) - result = await adapter._get_channel_context("19:abc") + result = await adapter._get_graph_context("19:abc") assert result is None async def test_invalid_json_in_cache(self): @@ -1561,7 +1561,7 @@ async def test_invalid_json_in_cache(self): state._cache["teams:channelContext:19:abc"] = "not valid json {" chat = _make_mock_chat(state) await adapter.initialize(chat) - result = await adapter._get_channel_context("19:abc") + result = await adapter._get_graph_context("19:abc") assert result is None async def test_valid_context_from_cache(self): @@ -1570,11 +1570,311 @@ async def 
test_valid_context_from_cache(self): state._cache["teams:channelContext:19:abc"] = json.dumps({"team_id": "t1", "channel_id": "c1"}) chat = _make_mock_chat(state) await adapter.initialize(chat) - result = await adapter._get_channel_context("19:abc") + result = await adapter._get_graph_context("19:abc") assert result is not None assert result["team_id"] == "t1" +# --------------------------------------------------------------------------- +# vercel/chat#403 — DM conversation IDs for Microsoft Graph API +# --------------------------------------------------------------------------- + + +class TestGraphDmConversationIdResolution: + """Regression tests for vercel/chat#403. + + Bot Framework hands out opaque DM conversation IDs (e.g. + ``a:1xWhatever``) which the Graph ``/chats/{chat-id}/messages`` + endpoint rejects with 404. The fix caches the user's AAD object ID + on every incoming activity and resolves the canonical Graph chat ID + (``19:{aadId}_{botId}@unq.gbl.spaces``) before issuing Graph calls. + """ + + def test_chat_id_from_context_uses_dm_graph_chat_id(self): + """What to fix if this fails: ``_chat_id_from_context`` must return + the cached ``graph_chat_id`` for any DM context. Returning the raw + Bot Framework conversation ID instead reproduces the upstream bug + where Graph calls 404 for every DM. + """ + result = TeamsAdapter._chat_id_from_context( + {"type": "dm", "graph_chat_id": "19:user-aad-id_bot-id@unq.gbl.spaces"}, + "a:opaque-conversation-id", + ) + assert result == "19:user-aad-id_bot-id@unq.gbl.spaces" + + def test_chat_id_from_context_uses_raw_id_for_no_context(self): + """What to fix if this fails: group chats (no cached context) must + fall back to the raw conversation ID — they work as-is with Graph. 
+ """ + result = TeamsAdapter._chat_id_from_context(None, "19:group-chat@thread.v2") + assert result == "19:group-chat@thread.v2" + + def test_chat_id_from_context_uses_raw_id_for_channel_context(self): + """What to fix if this fails: channel contexts go through the + ``/teams/{team-id}/channels/...`` path; if a channel context ever + leaks into ``_chat_id_from_context``, fall back to the raw ID + rather than the (absent) ``graph_chat_id``. + """ + result = TeamsAdapter._chat_id_from_context( + {"type": "channel", "team_id": "team-id", "channel_id": "channel-id"}, + "19:channel@thread.tacv2", + ) + assert result == "19:channel@thread.tacv2" + + async def test_cache_user_context_caches_dm_context_with_graph_chat_id(self): + """An incoming DM activity must cache a ``TeamsDmContext`` keyed by + the opaque base conversation ID, with ``graph_chat_id`` derived + from ``from.aadObjectId`` and the bot's app_id. + + What to fix if this fails: ``_cache_user_context`` must cache the + DM context whenever ``from.aadObjectId`` is present and the + conversation ID is *not* a channel/group chat (i.e. doesn't start + with ``19:``). Without this cache, ``fetch_messages`` falls back + to the raw Bot Framework ID and Graph returns 404. 
+ """ + adapter = _make_adapter(logger=_make_logger(), app_id="bot-app-id") + state = _make_mock_state() + chat = _make_mock_chat(state) + await adapter.initialize(chat) + + activity = { + "type": "message", + "from": {"id": "29:1xUser", "aadObjectId": "00000000-0000-0000-0000-aaaaaaaaaaaa"}, + "conversation": {"id": "a:opaque-dm-conversation-id"}, + "serviceUrl": "https://smba.trafficmanager.net/teams/", + } + await adapter._cache_user_context(activity) + + raw_context = state._cache.get("teams:channelContext:a:opaque-dm-conversation-id") + assert raw_context is not None, "DM context must be cached for Graph chat ID resolution" + ctx = json.loads(raw_context) + assert ctx["type"] == "dm" + assert ctx["graph_chat_id"] == "19:00000000-0000-0000-0000-aaaaaaaaaaaa_bot-app-id@unq.gbl.spaces" + + async def test_cache_user_context_skips_dm_for_channel_conversation(self): + """Channel conversations (``19:...``) must not cache a DM context + even when ``from.aadObjectId`` is present. + + What to fix if this fails: the conversation-ID prefix guard in + ``_cache_user_context`` is broken — a channel activity is leaking + into the DM path and would over-write the channel context. 
+ """ + adapter = _make_adapter(logger=_make_logger(), app_id="bot-app-id") + state = _make_mock_state() + chat = _make_mock_chat(state) + await adapter.initialize(chat) + + activity = { + "type": "message", + "from": {"id": "29:1xUser", "aadObjectId": "00000000-0000-0000-0000-aaaaaaaaaaaa"}, + "conversation": {"id": "19:channel@thread.tacv2"}, + "serviceUrl": "https://smba.trafficmanager.net/teams/", + "channelData": { + "team": {"aadGroupId": "team-aad"}, + "channel": {"id": "19:channel@thread.tacv2"}, + }, + } + await adapter._cache_user_context(activity) + + raw_context = state._cache.get("teams:channelContext:19:channel@thread.tacv2") + assert raw_context is not None + ctx = json.loads(raw_context) + # Adversarial check: a DM-like channel ID (still starts with "19:") + # should resolve to the channel context, not a DM context. + assert ctx.get("type") != "dm" + assert ctx["team_id"] == "team-aad" + + async def test_cache_user_context_skips_dm_when_no_aad_object_id(self): + """No ``aadObjectId`` (e.g. bot-to-bot activity) means we cannot + derive the canonical Graph chat ID — skip caching the DM context + entirely rather than caching a malformed one. + """ + adapter = _make_adapter(logger=_make_logger(), app_id="bot-app-id") + state = _make_mock_state() + chat = _make_mock_chat(state) + await adapter.initialize(chat) + + activity = { + "type": "message", + "from": {"id": "28:bot-id"}, # no aadObjectId + "conversation": {"id": "a:opaque-dm"}, + "serviceUrl": "https://smba.trafficmanager.net/teams/", + } + await adapter._cache_user_context(activity) + assert state._cache.get("teams:channelContext:a:opaque-dm") is None + + async def test_cache_user_context_rejects_non_guid_aad_object_id(self): + """Defense-in-depth: ``aadObjectId`` values that aren't GUIDs must + not be formatted into the Graph chat ID. 
Bot Framework JWT + verification authenticates the activity envelope but does not + constrain ``from.aadObjectId`` shape; a malformed value containing + ``/`` ``?`` ``#`` ``:`` could otherwise inject into the Graph URL. + + What to fix if this fails: ``_cache_user_context`` lost its + ``_AAD_OBJECT_ID_PATTERN.fullmatch`` guard. Re-add it before the + DM context is written. + """ + adapter = _make_adapter(logger=_make_logger(), app_id="bot-app-id") + state = _make_mock_state() + chat = _make_mock_chat(state) + await adapter.initialize(chat) + + # Various malformed shapes — a path-injection attempt, missing + # hyphens, too short, non-hex characters. + for bogus in [ + "user-aad-id", # not GUID-shaped + "../etc/passwd", # path traversal + "00000000-0000-0000-0000-aaaa/messages", # injection + "00000000000000000000000000000000", # right length, no hyphens + "ZZZZZZZZ-ZZZZ-ZZZZ-ZZZZ-ZZZZZZZZZZZZ", # non-hex + "00000000-0000-0000-0000", # too short + ]: + activity = { + "type": "message", + "from": {"id": "29:1xUser", "aadObjectId": bogus}, + "conversation": {"id": f"a:opaque-{bogus}"}, + "serviceUrl": "https://smba.trafficmanager.net/teams/", + } + await adapter._cache_user_context(activity) + cached = state._cache.get(f"teams:channelContext:a:opaque-{bogus}") + assert cached is None, ( + f"_cache_user_context cached a DM context with malformed " + f"aadObjectId={bogus!r}; the GUID guard must reject this " + f"before formatting it into a Graph chat ID" + ) + + async def test_fetch_messages_uses_legacy_cache_shape_for_channel(self): + """End-to-end backwards-compat: cached entries written before + vercel/chat#403 lack the ``type`` discriminator. They must still + route through ``fetch_messages``'s channel branch (treated as + ``type=channel`` by ``_chat_id_from_context``), not get + misclassified or crash on the missing key. + + What to fix if this fails: ``_chat_id_from_context`` or + ``_get_graph_context`` lost the legacy-shape fallthrough. 
The + contract is "absent ``type`` ⇒ treated as channel". + """ + adapter = _make_adapter(logger=_make_logger(), app_id="bot-app-id") + state = _make_mock_state() + # Legacy shape — no ``type`` key, only ``team_id`` + ``channel_id``. + state._cache["teams:channelContext:19:legacy-channel@thread.tacv2"] = json.dumps( + {"team_id": "team-aad", "channel_id": "19:legacy-channel@thread.tacv2"} + ) + chat = _make_mock_chat(state) + await adapter.initialize(chat) + + # The legacy cache should route to the channel-messages path — + # which uses ``team_id`` + ``channel_id``, NOT the chat-messages + # path (which would be wrong for a channel and would call Graph + # with the raw conversation ID). NOTE(review): the hook below is + # installed, but this test never asserts that it was called. + async def fake_channel_list(team_id: str, channel_id: str, limit: int = 50): + return [{"id": "1234", "from": {"user": {"id": "u1", "displayName": "x"}}, "body": {"content": "hi"}}] + + adapter._graph_list_channel_messages = fake_channel_list # type: ignore[assignment] + + ctx = await adapter._get_graph_context("19:legacy-channel@thread.tacv2") + assert ctx is not None, "Legacy cache shape did not load" + # ``_chat_id_from_context`` returns the raw base id when there's + # no ``graph_chat_id`` (channel context shape). The crucial assert + # is that we treat this as *not-DM* (no Graph 19:...@unq.gbl.spaces + # construction). + chat_id = adapter._chat_id_from_context(ctx, "19:legacy-channel@thread.tacv2") + assert chat_id == "19:legacy-channel@thread.tacv2", ( + f"_chat_id_from_context misclassified legacy cache entry as DM; " + f"got {chat_id!r}, expected the raw conversation ID. The " + f"absent-``type`` legacy shape must fall through to channel " + f"semantics." 
+ ) + + async def test_fetch_messages_uses_graph_chat_id_for_dm(self): + """End-to-end: a cached DM context must redirect ``fetch_messages`` + to the canonical ``19:{aadId}_{botId}@unq.gbl.spaces`` chat ID, + not the opaque Bot Framework conversation ID. + + What to fix if this fails: ``fetch_messages`` is calling + ``_graph_list_chat_messages`` with the raw ``base_conversation_id`` + instead of the resolved Graph chat ID. Graph returns 404 in + production for every DM in this state. Confirm + ``_chat_id_from_context`` is invoked and its result threads through + the chat-id parameter. + """ + from chat_sdk.adapters.teams.adapter import TeamsAdapter as _TA + + adapter = _make_adapter(logger=_make_logger(), app_id="bot-app-id") + state = _make_mock_state() + # Cache the DM context as if a previous activity had landed. + state._cache["teams:channelContext:a:opaque-dm-id"] = json.dumps( + {"type": "dm", "graph_chat_id": "19:user-aad_bot-app-id@unq.gbl.spaces"} + ) + chat = _make_mock_chat(state) + await adapter.initialize(chat) + + called_with: list[str] = [] + + async def fake_list(chat_id: str, params: Any) -> list[dict[str, Any]]: + called_with.append(chat_id) + return [ + { + "id": "msg-1", + "createdDateTime": "2024-06-01T12:00:00Z", + "body": {"contentType": "text", "content": "Hi"}, + "from": {"user": {"id": "user-aad", "displayName": "Alice"}}, + } + ] + + adapter._graph_list_chat_messages = fake_list # type: ignore[method-assign] + adapter._get_graph_token = AsyncMock(return_value="t") # type: ignore[method-assign] + + tid = adapter.encode_thread_id( + TeamsThreadId( + conversation_id="a:opaque-dm-id", + service_url="https://smba.trafficmanager.net/teams/", + ) + ) + + result = await adapter.fetch_messages(tid) + + assert called_with == ["19:user-aad_bot-app-id@unq.gbl.spaces"], ( + "fetch_messages must dispatch DM Graph calls to the resolved " + "graph_chat_id, not the opaque Bot Framework conversation ID. " + f"Saw chat_id={called_with!r}." 
+ ) + assert len(result.messages) == 1 + # Silence unused-import warning if the helper above is removed. + assert _TA is TeamsAdapter + + async def test_fetch_messages_falls_back_to_raw_id_when_no_dm_context(self): + """Pre-#403 behavior preservation: when no DM context is cached + (e.g. a group chat conversation, or an installation that hasn't + seen an activity yet), fall back to the raw conversation ID. This + keeps group chats working as before. + """ + adapter = _make_adapter(logger=_make_logger()) + state = _make_mock_state() # empty + chat = _make_mock_chat(state) + await adapter.initialize(chat) + + called_with: list[str] = [] + + async def fake_list(chat_id: str, params: Any) -> list[dict[str, Any]]: + called_with.append(chat_id) + return [] + + adapter._graph_list_chat_messages = fake_list # type: ignore[method-assign] + adapter._get_graph_token = AsyncMock(return_value="t") # type: ignore[method-assign] + + tid = adapter.encode_thread_id( + TeamsThreadId( + conversation_id="19:group-chat@thread.v2", + service_url="https://smba.trafficmanager.net/teams/", + ) + ) + await adapter.fetch_messages(tid) + + assert called_with == ["19:group-chat@thread.v2"] + + # --------------------------------------------------------------------------- # _extract_text_from_graph_message edge cases # --------------------------------------------------------------------------- diff --git a/tests/test_thread_faithful.py b/tests/test_thread_faithful.py index 693591d..d86c858 100644 --- a/tests/test_thread_faithful.py +++ b/tests/test_thread_faithful.py @@ -784,9 +784,67 @@ async def slow_stream() -> AsyncIterator[str]: open_count = final_md.count("**") assert open_count % 2 == 0 - # it("should pass stream options from current message context") + # it.each([...])("should pass stream options from Slack current message context via $label") + # Upstream parity: vercel/chat#330. 
Slack carries the workspace ID in + # different shapes depending on webhook envelope; the post() dispatcher + # must extract it from each. + # + # What to fix if this fails: see ``_extract_slack_recipient_team_id`` in + # ``src/chat_sdk/thread.py``. The interactive payload lookup must follow + # the order: top-level ``team_id`` → top-level ``team`` (string) → + # nested ``team.id`` (object) → ``user.team_id`` fallback. + @pytest.mark.parametrize( + ("label", "raw", "expected_team_id"), + [ + # Happy paths: each shape resolves to its expected team id. + ("team_id", {"team_id": "T123", "type": "app_mention"}, "T123"), + ("team string", {"team": "T234", "type": "message"}, "T234"), + ("team.id object", {"team": {"id": "T345"}, "type": "block_actions"}, "T345"), + ( + "user.team_id fallback", + {"type": "block_actions", "user": {"team_id": "T456"}}, + "T456", + ), + # Adversarial fall-throughs — each MUST cascade to the next + # resolution step rather than capturing a malformed value. + # See docs/SELF_REVIEW.md principle #1 (input sweep). + ( + "team dict missing id falls through to user.team_id", + {"team": {"domain": "acme"}, "user": {"team_id": "T567"}}, + "T567", + ), + ( + "empty team_id string falls through to team string", + {"team_id": "", "team": "T678"}, + "T678", + ), + ( + "empty team_id and team strings fall through to user.team_id", + {"team_id": "", "team": "", "user": {"team_id": "T789"}}, + "T789", + ), + # Final-fallback: nothing matches → None propagates as + # ``recipient_team_id=None``. The Slack adapter raises a clear + # multi-workspace error rather than calling Slack with the + # wrong workspace. 
+ ("non-dict raw returns None", "not a dict", None), + ("empty dict returns None", {}, None), + ( + "non-string user.team_id returns None", + {"user": {"team_id": 12345}}, + None, + ), + ( + "team dict with non-string id returns None", + {"team": {"id": 999}}, + None, + ), + ], + ) @pytest.mark.asyncio - async def test_should_pass_stream_options_from_current_message_context(self): + async def test_should_pass_stream_options_from_current_message_context( + self, label: str, raw: Any, expected_team_id: str | None + ): adapter = create_mock_adapter() state = create_mock_state() @@ -806,7 +864,7 @@ async def mock_stream(thread_id: str, text_stream: Any, options: Any = None) -> thread_id="slack:C123:1234.5678", text="test", formatted={"type": "root", "children": []}, - raw={"team_id": "T123"}, + raw=raw, author=Author( user_id="U456", user_name="user", @@ -825,7 +883,131 @@ async def mock_stream(thread_id: str, text_stream: Any, options: Any = None) -> assert len(stream_call_args) == 1 options = stream_call_args[0][2] assert options.recipient_user_id == "U456" - assert options.recipient_team_id == "T123" + assert options.recipient_team_id == expected_team_id, ( + f"label={label!r}: expected {expected_team_id!r} but got " + f"{options.recipient_team_id!r}. The interactive-payload " + f"team_id extraction (vercel/chat#330) must walk team_id → " + f"team (string) → team.id (object) → user.team_id." + ) + + # it("should forward structured stream chunks to adapter.stream from an action-created thread") + # Upstream parity: vercel/chat#330. Verifies that a block_actions + # context (where ``raw.team`` is an object) still routes a *structured* + # stream (text + ``task_update`` chunks) into ``adapter.stream``, not + # the text-only fallback path. Before #330, an undefined + # ``recipient_team_id`` would still call ``adapter.stream`` but the + # adapter would fail to authenticate; here we assert the chunks land + # untouched and ``recipient_team_id`` is set. 
@pytest.mark.asyncio
async def test_should_forward_structured_stream_chunks_to_adapter_stream_from_an_action_created_thread(self):
    """Post a structured stream from a ``block_actions`` thread and inspect ``adapter.stream``.

    The thread's current message carries ``team`` as an object in its
    ``raw`` payload. The test passes when ``adapter.stream`` is invoked
    exactly once with ``recipient_team_id == "T123"`` and both the text
    chunk and the ``TaskUpdateChunk`` appear among the chunks it consumed.
    """
    adapter = create_mock_adapter()
    state = create_mock_state()

    seen_options: list[Any] = []
    seen_chunks: list[Any] = []

    async def recording_stream(thread_id: str, text_stream: Any, options: Any = None) -> RawMessage:
        # Record the options first, then drain the stream so every chunk
        # handed to the adapter is visible to the assertions below.
        seen_options.append(options)
        async for item in text_stream:
            seen_chunks.append(item)
        return RawMessage(id="msg-stream", thread_id=thread_id, raw="Hello")

    adapter.stream = recording_stream  # type: ignore[attr-defined]

    clicking_user = Author(
        user_id="U456",
        user_name="user",
        full_name="Test User",
        is_bot=False,
        is_me=False,
    )
    # block_actions shape: ``team`` is an object rather than a plain string.
    block_actions_payload = {"team": {"domain": "workspace", "id": "T123"}, "type": "block_actions"}
    action_message = Message(
        id="action-msg",
        thread_id="slack:C123:1234.5678",
        text="",
        formatted={"type": "root", "children": []},
        raw=block_actions_payload,
        author=clicking_user,
        metadata=MessageMetadata(date_sent=datetime.now(tz=timezone.utc), edited=False),
        attachments=[],
    )
    thread = _make_thread(adapter, state, current_message=action_message)

    task_update = TaskUpdateChunk(id="task-1", status="pending", title="Thinking", type="task_update")

    async def text_then_task() -> AsyncIterator[Any]:
        # One plain-text chunk followed by one structured chunk.
        for piece in ("Picking option...", task_update):
            yield piece

    await thread.post(text_then_task())

    assert len(seen_options) == 1
    (stream_options,) = seen_options
    assert stream_options.recipient_team_id == "T123"
    # Presence checks only: the exact chunk sequence is not asserted.
    assert "Picking option..." in seen_chunks
    assert task_update in seen_chunks

# vercel/chat#330 regression — concurrent block_actions payloads with
# different team.id values must not cross-contaminate.
# Hazard #6 (ContextVar boundaries): even though team_id flows through the
# per-thread ``currentMessage`` (not a ContextVar), running two posts
# concurrently still verifies that no module-level cache or shared
# adapter state leaks the team between requests.
#
# What to fix if this fails: a regression in
# ``_extract_slack_recipient_team_id`` or in how ThreadImpl reads
# ``self._current_message.raw`` is letting one request see another's
# team. Check for any ContextVar / module-global the extraction relies
# on.
@pytest.mark.asyncio
async def test_concurrent_block_actions_team_ids_do_not_cross_contaminate(self):
    """Run two ``block_actions`` posts concurrently and compare the team each one streamed with.

    Every post builds its own mock adapter and state, posts a one-chunk
    text stream, and records ``(user_id, options.recipient_team_id)`` from
    inside its ``adapter.stream`` stand-in. The recorded pairs must map
    each user to the team id carried by that user's own payload.
    """
    observed: list[tuple[str, str | None]] = []
    observed_guard = asyncio.Lock()

    async def post_as(team_id: str, user_id: str) -> None:
        adapter = create_mock_adapter()
        state = create_mock_state()

        async def recording_stream(thread_id: str, text_stream: Any, options: Any = None) -> RawMessage:
            # Hand control back to the event loop once so the sibling post
            # can interleave before anything is recorded.
            await asyncio.sleep(0)
            async for _chunk in text_stream:
                pass
            async with observed_guard:
                observed.append((user_id, options.recipient_team_id))
            return RawMessage(id=f"msg-{team_id}", thread_id=thread_id, raw="ok")

        adapter.stream = recording_stream  # type: ignore[attr-defined]

        requester = Author(
            user_id=user_id,
            user_name=user_id,
            full_name=user_id,
            is_bot=False,
            is_me=False,
        )
        current_message = Message(
            id=f"original-{team_id}",
            thread_id="slack:C123:1234.5678",
            text="",
            formatted={"type": "root", "children": []},
            # block_actions shape (team is an object).
            raw={"type": "block_actions", "team": {"id": team_id}},
            author=requester,
            metadata=MessageMetadata(date_sent=datetime.now(tz=timezone.utc), edited=False),
            attachments=[],
        )

        thread = _make_thread(adapter, state, current_message=current_message)
        await thread.post(_create_text_stream(["hi"]))

    workspaces = (("T_AAA", "U_AAA"), ("T_BBB", "U_BBB"))
    await asyncio.gather(*(post_as(team, user) for team, user in workspaces))

    # Each request must see its own team_id, not the other's.
    assert dict(observed) == {"U_AAA": "T_AAA", "U_BBB": "T_BBB"}